
feat(deepseek_v4): PR1 skeleton — end-to-end inference with triton MoE#650

Merged

valarLip merged 48 commits into main from feat/deepseek-v4-pr1-skeleton on May 6, 2026.

Conversation

valarLip (Collaborator) commented Apr 25, 2026

Summary

DeepSeek-V4 end-to-end inference on real checkpoint (/data/DeepSeek-V4-Pro, TP=8). Now covers:

  • PR1 (skeleton): full V4 architecture + triton MoE + standard ATOM loader
  • pre2a / pre2c-A / pre2c-B: per-request state cache abstraction + classical KV cache via block_table (per paper §3.6.1)
  • PR3-main: multi-sequence batched dispatch via per-seq Python loop
  • PR-A (Phase 0/1/2 partial): backend gate scaffold + CPU-mirror metadata + swa_write Triton kernel + update_compressor_states Triton kernel (pos % (2*ratio) ring buffer per paper §3.6.1 + eq 11)
  • fused_compress_attn: PR4 fused boundary compress kernel (RMSNorm + GPT-J RoPE + ape softmax-pool inside the Triton K-loop)
  • batched compressor (SGLang plan): compressor / update_compressor_states go from a per-seq Python loop (64 layers × num_seqs GPU launches per fwd) to ONE batched kernel call per layer, driven by 16-byte plan rows [ragged_id, batch_id, position, window_len]
  • batched indexer (Phase 2b-i/ii): Indexer.forward_batched — single fp8_mqa_logits + topk + width/future mask + offset across all seqs, replacing per-seq Python loop. Hoisted Indexer Compressor out of dispatch loop (one batched call instead of 64 layers × num_seqs)
  • Phase 3 hoist (0eb32a92): all per-fwd-invariant metadata moved from per-layer forward to attn_metadata_builder. Eliminated ~1200 per-layer torch.as_tensor H2D copies in production fast path. Removed dummy_run special path — warmup goes through normal forward. Per-(ratio) sparse pack indices, indexer broadcast helpers, gather indices, layer-invariant GPU tensors (cu_starts/cu_ends/visible_end/width_mask/future_threshold) all built once per fwd in builder
  • ROCm hotfix: sparse_attn_v4 triton kernels' BLOCK_H raised to 16 to satisfy the gfx9xx/gfx950 MFMA tile minimum (was crashing at JIT compile in the TritonAMDGPUOptimizeDotOperands pass on PR #678, "feat(deepseek_v4): use triton sparse attn kernel and move attn kernel out of loop")

Verified on real ckpt: single-seq + 4-batch, English + Chinese, coherent outputs across all slots. GSM8K 100-sample at 0.96 (3-shot, flexible-extract) — see Accuracy below.
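The pos % (2*ratio) ring buffer mentioned above (per-request compressor state, eq 11) can be sketched in plain Python. This is an illustrative model of the contract, not the real Triton kernel: per-token writes land at pos modulo twice the compress ratio, and the consumer selects the committed half by block-id parity instead of rolling the buffer every decode step.

```python
# Illustrative sketch of the compressor state ring buffer (names like
# `ratio` / `write_slot` / `committed_half` are hypothetical, not the
# real update_compressor_states API).

def write_slot(pos: int, ratio: int) -> int:
    """Ring-buffer slot for the token at absolute position `pos`."""
    return pos % (2 * ratio)

def committed_half(block_id: int, ratio: int, state):
    """Consumer view: block-id parity selects the fully-written half
    (the other half is still being filled by in-flight tokens)."""
    if block_id % 2 == 0:
        return state[:ratio]
    return state[ratio:]

# Toy fill for ratio=4: tokens 0..11 wrap around an 8-slot ring, so the
# first half is overwritten by tokens 8..11 while 4..7 stay committed.
ratio = 4
state = [None] * (2 * ratio)
for pos in range(12):
    state[write_slot(pos, ratio)] = pos
```

The point of the parity read is that no memcpy/roll is ever needed: the writer only ever appends at `pos % (2*ratio)`, and the reader's block id tells it which half is stable.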

Reproduce — offline 4-prompt batch

# Prerequisites
pip install -e /triton-test/python/triton_kernels/

ATOM_USE_TRITON_MOE=1 AITER_LOG_LEVEL=WARNING \
python -m atom.examples.simple_inference \
  --model /data/DeepSeek-V4-Pro \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --max-num-seqs 4 \
  --max-num-batched-tokens 1024 \
  --max-model-len 1024 \
  --gpu-memory-utilization 0.85 \
  --enforce-eager \
  --temperature 0.0 \
  --max-tokens 256

End-to-end 178s for 4 prompts × 256 tokens (eager mode, no CUDAGraph yet, no AITER GEMM tuning for V4 shapes). TTFT ≈ 1.0s on a 10-token prefill.

Sample output (4-prompt batch, temperature=0, max_tokens=256)

Prompt: introduce yourself
Completion: Hello! 👋 I'm DeepSeek, an AI assistant created by DeepSeek
(深度求索) company. I'm here to help you with a wide range of tasks!
[...256 tokens, eos]

Prompt: list all prime numbers within 100
Completion: Here is the list of all prime numbers less than 100:
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
73, 79, 83, 89, 97
There are 25 prime numbers in total.

Prompt: 1+2+3=?
Completion: That's a simple addition problem.
1 + 2 + 3 = **6**

Prompt: 如何在一个月内增肌10公斤  [how to gain 10 kg of muscle in one month]
Completion: 一个月内增肌10公斤是一个非常极端且**几乎不可能实现**的目标,
除非你处于以下几种特殊情况:[...256 tokens, max_tokens]
[Translation: Gaining 10 kg of muscle in one month is an extremely
aggressive and **almost impossible** goal, unless you are in one of the
following special situations: ...]

Accuracy — GSM8K (100-sample, 3-shot, flexible-extract / strict-match)

Runs against the live OpenAI-compatible server. Validates Phase 3 hoist + batched indexer/compressor end-to-end.

Step 1: start the server (TP=8, fp8 KV, enforce-eager).

# Clean stale procs/cache before relaunch
pkill -9 -f 'atom.entrypoints' 2>/dev/null; sleep 3
pkill -9 -f 'multiprocessing.spawn' 2>/dev/null; sleep 3
rm -rf /root/.cache/atom/*

ATOM_USE_TRITON_MOE=1 AITER_LOG_LEVEL=WARNING \
python -m atom.entrypoints.openai_server \
  --model /data/DeepSeek-V4-Pro \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --max-num-seqs 16 \
  --gpu-memory-utilization 0.85 \
  --server-port 8000 \
  --enforce-eager \
  --max-model-len 4096 \
  > /tmp/atom_server.log 2>&1 &

# Wait for server ready (poll /v1/models until it responds)
for i in $(seq 1 60); do
  if curl -sf http://localhost:8000/v1/models > /dev/null 2>&1; then
    echo "Server ready"; break
  fi
  sleep 5
done

Step 2: run lm_eval against the server.

lm_eval --model local-completions \
  --model_args "model=/data/DeepSeek-V4-Pro,base_url=http://localhost:8000/v1/completions,num_concurrent=16,max_retries=2,tokenized_requests=False,trust_remote_code=True" \
  --tasks gsm8k \
  --num_fewshot 3 \
  --limit 100

Result (3:05 wall-clock, ~1.8 s/sample at concurrency 16):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  | 0.96|±  |0.0197|
|     |       |strict-match    |     3|exact_match|↑  | 0.96|±  |0.0197|

96/100 correct. Stable 0.94–0.96 across 50/100-sample bisects during Phase 3 incremental landings.

Accuracy progression across landings

| Landing | Commit | Sample | Score | Notes |
|---|---|---|---|---|
| PR1 skeleton smoke | early | 25 / 3-shot | 0.88 | initial validation |
| Batched compressor | 352338df | 50 / 3-shot | 0.96 | SGLang plan tensors |
| Batched indexer (b-i + ii) | caed0c7f | 50 / 3-shot | 0.96 | fp8_mqa_logits + topk batched |
| Phase 3 hoist + cleanup | 0eb32a92 | 100 / 3-shot | 0.96 ± 0.020 | metadata builder hoist, no per-layer H2D |
| top_k_per_row prefill+decode | 55be12a2 | 100 / 3-shot | 0.97 ± 0.017 | aiter radix kernel adoption (depends on ROCm/aiter#3012) |
| CUDAGraph (paged-decode kernel + sentinel padding) | cb7f84f1 | 50 / 3-shot | 0.98 ± 0.020 | wider CG capture [1,2,4,8,16,32,64] + max_num_seqs=64 |

CUDAGraph (production decode path)

The eager-mode workaround (--enforce-eager) is no longer required. Enable
CUDAGraph capture by passing --cudagraph-capture-sizes explicitly (must
be a quoted Python list literal — bare 1 2 4 will be rejected by argparse
as unrecognized arguments). The default cuda_graph_sizes=[512] requires
max_num_seqs ≥ 512, which is rarely what you want — always pass
--cudagraph-capture-sizes for the actual decode-batch range your workload
hits.
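The quoting requirement above can be sketched in a few lines. This is a hypothetical reconstruction of how such a flag might be parsed (a single string evaluated as a Python list literal, here via ast.literal_eval), not the actual atom CLI code; it shows why bare `1 2 4` becomes three argv tokens that argparse rejects.

```python
# Hypothetical sketch: why --cudagraph-capture-sizes needs a quoted
# Python list literal. The flag takes ONE string argument; a bare
# "1 2 4" leaves "2" and "4" as unrecognized positional arguments.
import argparse
import ast

def parse_capture_sizes(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--cudagraph-capture-sizes",
                        type=ast.literal_eval, default=[512])
    args = parser.parse_args(argv)
    return args.cudagraph_capture_sizes

sizes = parse_capture_sizes(["--cudagraph-capture-sizes", "[1,2,4,8,16,32,64]"])
```

With no flag, the sketch falls back to the `[512]` default described above, which only makes sense when max_num_seqs is at least 512.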

pkill -9 -f 'atom.entrypoints' 2>/dev/null; sleep 3
pkill -9 -f 'multiprocessing.spawn' 2>/dev/null; sleep 3
rm -rf /root/.cache/atom/*

ATOM_USE_TRITON_MOE=1 AITER_LOG_LEVEL=WARNING \
python -m atom.entrypoints.openai_server \
  --model /data/DeepSeek-V4-Pro \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --max-num-seqs 64 \
  --gpu-memory-utilization 0.85 \
  --server-port 8000 \
  --level 0 \
  --cudagraph-capture-sizes "[1,2,4,8,16,32,64]" \
  --max-model-len 4096 \
  > /tmp/atom_server.log 2>&1 &

# Wait for server ready (poll /v1/models until it responds)
for i in $(seq 1 60); do
  if curl -sf http://localhost:8000/v1/models > /dev/null 2>&1; then
    echo "Server ready"; break
  fi
  sleep 5
done

Capture itself takes ~2.7s for the 7-size sweep above. Decode replay path
on a typical workload sees a ~4.3× TPOT speedup vs eager. lm_eval (50/100
samples, 3-shot) returns 0.96–0.98 — same accuracy band as eager.

How it works

V4's decode dispatch (is_pure_decode = uniform tokens-per-seq AND no fresh prefill, doc §7.4) routes into a custom triton kernel
sparse_attn_v4_paged_decode (atom/model_ops/v4_kernels/paged_decode.py,
page_size = 1, with V4's per-head learnable attn_sink). Per layer, SWA
ring + compressor paged KV are physically merged into a single BF16
unified_kv pool — the kernel carries one base pointer; every index
(SWA / CSA / HCA) is a row offset into that pool. Index buffers
(kv_indices_{swa,csa,hca} + kv_indptr_{swa,csa,hca}) are constructed
once per fwd in stable forward_vars storage with metadata-time
constants (no .item(), no device-data-dependent allocation). CSA's
per-layer indexer-output translation goes through a fixed-grid triton
kernel csa_packed_write (atom/model_ops/v4_kernels/csa_packed_write.py).

The captured graph is invariant under actual_bs ≤ graph_bs padding
because the builder sentinel-pads metadata to the captured slot count:
per-token tensors carry -1 (consumer kernels skip on bid<0 /
src_id<0); indptr cumsums repeat the last value (kernel sees kv_len=0
and bails before any read). Without this, captured kernels read stale
buffer entries at replay → kv_indices OOB → Memory access fault by GPU.
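The sentinel protocol can be simulated in plain Python. This is an illustrative model (function names are invented, the real consumers are Triton kernels): the builder pads per-token batch ids with -1 and pads the indptr cumsum by repeating its last value, so a graph captured at graph_bs and replayed with fewer real sequences sees kv_len = 0 for padded slots and never touches the index buffers.

```python
# Pure-Python simulation of the CG sentinel-pad protocol described above.
# `pad_metadata` / `decode_kernel` are illustrative stand-ins.

def pad_metadata(batch_ids, indptr, graph_bs):
    """Pad per-token ids with -1; pad indptr by repeating the last cumsum."""
    padded_ids = batch_ids + [-1] * (graph_bs - len(batch_ids))
    padded_indptr = indptr + [indptr[-1]] * (graph_bs - (len(indptr) - 1))
    return padded_ids, padded_indptr

def decode_kernel(padded_ids, padded_indptr, kv_indices):
    """Stand-in for a captured kernel: gathers KV rows per slot, bailing
    on sentinel slots (bid < 0) or zero-length ranges (kv_len == 0)."""
    out = []
    for slot, bid in enumerate(padded_ids):
        kv_len = padded_indptr[slot + 1] - padded_indptr[slot]
        if bid < 0 or kv_len == 0:
            out.append(None)  # padded slot: no reads at all
            continue
        out.append(kv_indices[padded_indptr[slot]:padded_indptr[slot + 1]])
    return out

# 2 real sequences replayed through a graph captured at graph_bs=4.
ids, indptr = pad_metadata([0, 1], [0, 3, 5], graph_bs=4)
rows = decode_kernel(ids, indptr, kv_indices=[10, 11, 12, 20, 21])
```

Without the repeated-last-value padding, slots 2 and 3 would read whatever stale cumsum values sat in the captured buffer, which is exactly the OOB gather the paragraph above describes.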

Full design and rationale (4 layout / dispatch refactors vs upstream V4
reference, KV pool layout, MTP-1 indptr replication, 11 CG-friendliness
constraints, sentinel protocol, H2D dedup): see
atom/model_ops/v4_kernels/doc/ATOM_V4_PAGED_DECODE_DESIGN.{zh,en}.md
(working tree, not committed in this PR — finalising for next push).

Bugs fixed in this PR

| # | Bug | Fix |
|--:|---|---|
| 1 | weights_mapping substring collision (381/2519 params silently skipped) | WeightsMapper prefix-anchored remapping |
| 2 | wo_a FP8 shuffle after BF16 dequant (attn output cos = -0.002) | quant_type=No to skip CK shuffle |
| 3 | Hash routing missing route_scale (FFN output 5.2× too small) | topk_weights *= routed_scaling_factor |
| 4 | ActivationType.Swiglu causes 9× amplitude loss on gfx950 | Use standard Silu + triton post-kernel clamp |
| 5 | shared_experts.w2 reduce_results mismatch with FusedMoE | reduce_results=False + unified all_reduce |
| 6 | KV cache warmup pollution (stale data from dummy forward) | Reset all KV/Compressor/Indexer buffers on start_pos=0 |
| 7 | UE8M0 input quant rounding mismatch vs reference | Switch input quant path to match reference; correct MoE routing scale |
| 8 | Weight loading: only one-way coverage check (orphan ckpt params undetected) | Bidirectional coverage check + V4 hash-layer bias handling |
| 9 | Compressor state cache required per-decode roll memcpy | pos % (2*ratio) ring buffer; consumer reads halves by block-id parity |
| 10 | n_committed = (start_pos + 1) // ratio dropped boundaries inside MTP-N decode windows | n_committed = end_pos // ratio (covers the token_num > 1 case) |
| 11 | PR #678's _sparse_attn_*_triton_kernels crashed the TritonAMDGPUOptimizeDotOperands pass on gfx950 (block_h=2/4 below MFMA min tile 16×16×16) | block_h = 16 in all three wrappers |
| 12 | dummy_run special path: warmup skipped per_req_cache slots → CUDAGraph blocker, divergent code path | _populate_state_slot_mapping falls back to slot 0 on empty groups; warmup walks the normal forward path |
| 13 | Wider CG capture ([1,2,4,8,...]) replay → Memory access fault by GPU across all ranks. Root cause: model_runner pads decode batches to graph_bs but the V4 builder only filled [0:scheduled_bs] metadata entries; captured kernels iterate graph_bs * max_q_len slots at replay, so padded slots read stale buffers → garbage kv_indptr → kv_indices OOB | Sentinel-pad protocol: per-token tensors (batch_id_per_token, swa_write_indices) carry -1 for padded slots; indptr cumsums repeat the last value. Consumer kernels skip on bid < 0 / kv_len == 0. _fill_csa_paged_compress clamps batch_id and block_idx before fancy indexing |
| 14 | Host .long() cast on a stable int32 view broke the captured graph: .long() allocates a fresh int64 tensor each call, so the captured graph reads the capture-time data_ptr at replay → stale memory read → GPU fault | CG-safety dtype rule: stage int64 at the source (e.g. batch_id_per_token graduated to int64 — one buffer satisfies both PyTorch fancy-index and triton kernels). Never .long() / .to(int64) on the per-fwd path |

top_k_per_row migration (indexer prefill + decode)

Indexer._score_topk_* previously used PyTorch topk over -inf-padded
logits. Migrated to aiter top_k_per_row_prefill / top_k_per_row_decode
(radix backend, parametric k) — uniform [total_tokens, index_topk] int32
layout for both paths, kernel honors per-row valid range natively so the
fill_(-inf) is gone.
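The contract the migration relies on can be written as a plain-Python reference (the real implementations are aiter's radix kernels; this just spells out the semantics): each row carries its own valid length, entries past it never participate, and output columns past the per-row top-k carry the -1 sentinel.

```python
# Plain-Python reference for the top_k_per_row contract described above
# (illustrative, not the aiter kernel): per-row valid range replaces the
# old fill_(-inf) padding, and short rows are padded with -1 sentinels.

def top_k_per_row(logits, valid_lens, k):
    out = []
    for row, n in zip(logits, valid_lens):
        # Only columns [0, n) are live for this row.
        order = sorted(range(n), key=lambda j: row[j], reverse=True)[:k]
        out.append(order + [-1] * (k - len(order)))
    return out

idx = top_k_per_row(
    logits=[[0.1, 0.9, 0.5, 99.0],   # col 3 lies past the valid range
            [0.3, 0.2, 0.0, 0.0]],
    valid_lens=[3, 2],
    k=3,
)
```

Note that the 99.0 in row 0 is ignored entirely; with the old -inf-padding scheme that column would first have had to be overwritten before the topk call.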

Decode path (CG-friendly):

  • Pre-allocate [max_bs, max_model_len_idx] fp32 logits + [max_bs, index_topk] int32 indices buffers in forward_vars.
  • deepgemm_fp8_paged_mqa_logits writes valid cols [0, n_committed).
  • top_k_per_row_decode(k=index_topk) writes top-k indices + -1 sentinels for cols past per-row valid range; _post_process_topk width-masks the sentinels via width_mask_gpu.

Prefill path (eager-only):

  • top_k_per_row_prefill(cu_starts, cu_ends, k=index_topk) honors per-row causal window from fp8_mqa_logits output; same [N, index_topk] output layout as decode.

Bug found and fixed during the migration: the builder's compress_topk_src_gpu was indexing indexer_topk_batched with stride max_k = max(k_per_seq) (dynamic, derived from prefill's torch.topk output shape). Collapsing both paths to index_topk stride required unifying the index in builder; without the fix, GSM8K-50 dropped from 0.95+ to 0.76 (consumer gathered wrong cells).
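A toy reproduction makes the failure mode above concrete (illustrative names, not the builder code): indexing a flat buffer with the wrong row stride gathers cells from the wrong rows silently, with no shape error, which is why the symptom was an accuracy drop rather than a crash.

```python
# Toy version of the stride bug: a flat [num_seqs, stride] buffer read
# with a mismatched stride returns plausible-looking but wrong values.

def gather(flat, row, col, stride):
    return flat[row * stride + col]

index_topk = 4   # the fixed stride both paths were unified to
max_k = 3        # the old dynamic stride derived from torch.topk's shape
flat = list(range(2 * index_topk))  # 2 rows laid out with stride index_topk

right = gather(flat, row=1, col=0, stride=index_topk)  # row 1 starts at 4
wrong = gather(flat, row=1, col=0, stride=max_k)       # lands inside row 0
```

`wrong` reads a valid in-bounds cell from the previous row, so nothing faults; only end-to-end accuracy (GSM8K in this case) exposes it.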

Dependencies: ROCm/aiter#3012 exposes the k kwarg on top_k_per_row_{prefill,decode}. Without that PR the kwarg is silently ignored and aiter falls back to k=2048 — still correct but allocates oversized output.

Kernel parity validation: stand-alone tests at V4 shapes (k=1024, varying bs ∈ {1, 4, 16, 64, 128, 256} × ctx) all pass against the torch.topk reference. Files:

  • /app/logs_claude/deepseek_v4/test_top_k_per_row_decode_v4.py
  • /app/logs_claude/deepseek_v4/test_top_k_per_row_prefill_v4.py

MoE paths

| Path | Env var | Status | Notes |
|---|---|---|---|
| aiter fused_moe (CK) | (default) | broken (a16w4+Swiglu bug on gfx950) | Fastest but broken |
| triton matmul_ogs | ATOM_USE_TRITON_MOE=1 | verified | TPOT 0.33–0.52 s/tok (4-seq batch, eager) |
| torch per-expert | ATOM_V4_TORCH_MOE=1 | verified | Very slow, debug only |

Note: ATOM_USE_TRITON_MOE=1 is mandatory for V4 accuracy. Without it GSM8K-50 drops from 0.95+ to ~0.6 (broken aiter MoE path on gfx950).

V4 Attention Backend (PR-A migration + Phase 2/3)

Selects between legacy per-seq Python dispatch and new batched V4AttentionBackend. The new backend removes ~256 GPU→CPU .item() syncs per forward and ~1200 per-layer torch.as_tensor H2D copies, prerequisite for CUDAGraph.

| Variable | Type | Default | Description |
|---|---|---|---|
| ATOM_V4_BACKEND | str | legacy | `new` routes through V4AttentionBackend |
| ATOM_V4_BACKEND_LAYERS | csv int | "" (= all) | Per-layer A/B bisect (e.g. 0,3,15,30) |
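A minimal sketch of how such a per-layer gate might combine the two variables (the real use_new_v4_backend lives in atom/model_ops/v4_backend_gate.py; this signature and the env handling are illustrative assumptions):

```python
# Hypothetical per-layer backend gate: ATOM_V4_BACKEND selects the path,
# ATOM_V4_BACKEND_LAYERS (csv of layer ids) optionally narrows it for
# A/B bisects. Empty csv means "all layers".

def use_new_v4_backend(layer_id: int, backend: str, layers_csv: str) -> bool:
    if backend != "new":
        return False
    if not layers_csv:
        return True
    return layer_id in {int(tok) for tok in layers_csv.split(",")}
```

With `ATOM_V4_BACKEND=new ATOM_V4_BACKEND_LAYERS=0,3,15,30`, only those four layers take the batched path while the rest stay on legacy dispatch, which is what makes per-layer bisection of accuracy regressions possible.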

Currently landed:

  • Phase 1a: swa_write Triton kernel (per-token positions % win, gated by use_new_v4_backend(layer_id))
  • Phase 2 partial: CPU-mirror metadata (cu_seqlens_q_cpu, state_slot_mapping_cpu, start_pos_per_seq_cpu) — eliminates per-seq .tolist() / .item() syncs in dispatch
  • State cache rewrite: update_compressor_states Triton kernel writes per-token at pos % STATE_SIZE; Compressor.forward reads A/B halves by block-id parity (no roll)
  • Batched compressor (SGLang plan style): attn_metadata_builder emits per-(ratio, overlap) CompressPlan (compress_plan + write_plan + cu_compress_cpu) in prepare_prefill / prepare_decode. Compressor.forward / update_compressor_states consume plan + state_slot_mapping; one batched kernel launch per layer per fwd replaces the 64 layers × num_seqs per-seq launches
  • Batched Indexer (Phase 2b-i + 2b-ii): hoisted Indexer Compressor out of per-seq loop. Indexer.forward_batched performs single fp8_mqa_logits + topk + width/future mask + offset across all seqs (byte-equal vs per-seq, validated on 31k+ A/B cases)
  • Phase 3 hoist (Step A/B/C + Tier 3/4): all per-fwd-invariant metadata constructed once in builder helpers (_attach_v4_per_fwd_meta, _build_v4_pack_meta_for_ratio, _build_v4_indexer_meta, _build_v4_gather_indices). V4Attention.forward reads from attn_metadata GPU views — zero per-layer torch.as_tensor. Init-time hoist for get_hip_quant / weights_scale / env-var checks. Module-level cache for V4_FORCE_UE8M0_QUANT / V4_USE_REF_QUANT / V4_AITER_HC_POST

Remaining (future PRs): full backend extraction (v4_attention_backend.py), CUDAGraph capture + replay (next PR — buffer pre-allocation + build_for_cudagraph_capture implementation).

Known limitations

  • CSA Indexer FP8 cache stays on cp_gather_indexer_k_quant_cache — decoupled from the new sparse_attn_v4_paged_decode kernel, not merged into unified_kv (different dtype). Out of scope for this PR.
  • AITER GEMM not tuned for V4 shapes: the log shows many `not found tuned config in /tmp/aiter_configs/...` warnings with fallback to default configs — bounds throughput
  • per-req state cache is 25.59 MB (HCA tail 60% / SWA 30% / CSA 9%), which eats into --max-num-seqs capacity. SGLang dsv4-rebase has bf16 state + SGLANG_OPT_USE_ONLINE_COMPRESS (HCA ring 128→1, storing the online softmax-pool reduction state instead of raw KV) — cuts per-req to ≈ 1.3 MB. Future PR.
  • 46 params unloaded: 3× hash-layer e_score_correction_bias (expected) + 43× MTP params (PR5)

Files changed

| File | Change |
|---|---|
| atom/config.py | V4-to-V3 config registry + V4 field re-injection |
| atom/models/deepseek_v4.py | Full V4 model + multi-seq dispatch + state cache refactor + batched compressor/indexer + Phase 3 hoist (forward reads only from attn_metadata) |
| atom/model_loader/loader.py | WeightsMapper auto-read + bidirectional coverage check |
| atom/model_ops/v4_kernels/ | NEW: swa_write, update_compressor_states, fused_compress_attn, compress_plan (SGLang plan generator), paged_decode (V4 sparse decode kernel for unified KV pool), csa_packed_write (per-layer indexer→paged-offset translation, fixed-grid) |
| atom/model_ops/v4_backend_gate.py | NEW: per-layer backend selector |
| atom/model_ops/attentions/deepseek_v4_attn.py | V4 metadata builder + CPU-mirror metadata + context_lens plumb + compress_plans build + Phase 3 helpers + typed AttentionMetaData_DSV4 dataclass + _attach_v4_paged_decode_meta (3 indptr cumsums + SWA/HCA paged offsets) + build_for_cudagraph_capture (capture-time decode-shape synthesis with MTP-1 awareness) + CG-padding sentinel protocol (per-token tensors padded with -1; indptr padded by repeating last cumsum value) |
| atom/model_ops/moe.py | ATOM_USE_TRITON_MOE gate + swiglu_limit passthrough |
| atom/model_ops/quant_v4.py | UE8M0 input quant + FP4 e2m1 dequant |
| atom/model_ops/fused_moe_triton.py | CDNA4MXScaleLayout fix + swiglu_limit clamp |
| atom/model_ops/sparse_attn_v4.py | Explicit device= for multi-GPU + ROCm BLOCK_H=16 fix |
| atom/model_ops/block_manager.py | Per-req cache abstraction (PR3-pre2a) |
| atom/utils/envs.py | ATOM_V4_BACKEND / ATOM_V4_BACKEND_LAYERS |
| atom/utils/debug_helper/ | Generic env-gated dump / compare / ref-patch |
| docs/environment_variables.md | V4 backend env doc |

Test plan

  • Single prompt English/Chinese 512 tokens — coherent output
  • 4-prompt batched inference (256 tokens each) — coherent outputs across all slots
  • Byte-equal kernel-vs-reference for update_compressor_states (15/15 cases: prefill + decode + MTP)
  • fused_compress_attn parity vs reference (30 cases: single-seq, batched bs=4/8, MTP-3, HCA) — max_diff ≤ 4.77e-7
  • Indexer.forward_batched byte-equal vs per-seq path (31k+ A/B cases)
  • lm_eval GSM8K 100-sample × 3-shot (eager) — 0.96 ± 0.020 flexible-extract / strict-match
  • CUDAGraph capture + replay: [1,2,4,8,16,32,64] capture sizes + --max-num-seqs 64, GSM8K-50 = 0.98 ± 0.020, capture cost 2.7s, ~4.3× TPOT speedup vs eager
  • paged_decode kernel parity vs reference (3/3 cases bit-exact)
  • csa_packed_write kernel parity (7/7 cases pass)
  • swa_write rewrite parity (5/5 cases: decode / MTP / long-prefill / sentinel-only / V4-Pro head_dim)
  • lm_eval GSM8K full (1319 samples) — to run with CG enabled in next push

valarLip added 13 commits April 24, 2026 16:13
…arity

Adds the foundational scaffolding for DeepSeek-V4-Pro support — a major
architecture shift from V3.2 with mHC residuals, hybrid CSA+HCA attention,
hash routing, and grouped output LoRA. PR1 ships the eager-mode model code
with torch fallback kernels, validated against the official inference
implementation at bit-exact parity (max_abs_diff = 0.0).

Scope (PR1 only):
- New atom/models/deepseek_v4.py: full Compressor / Indexer / Attention /
  Gate / Expert / MoE / Block / MTPBlock / ParallelHead / Transformer port
  (~1200 lines). Single-rank only; plain nn.Linear / nn.Embedding for now.
- New atom/model_ops/sparse_attn_v4.py: torch fallbacks for sparse_attn
  and hc_split_sinkhorn (Sinkhorn-Knopp projection on Birkhoff polytope).
- New atom/model_ops/quant_v4.py: torch fallbacks for FP8/FP4 inplace
  QAT round-trip and Walsh-Hadamard transform (replaces fast_hadamard_transform
  which doesn't build on ROCm).
- Register DeepseekV4ForCausalLM in support_model_arch_dict.

Out of scope (tracked for PR2-6):
- Real HF checkpoint loading (PR2 = FP4 e2m1 loader, PR3 = TP + KV cache).
- AITER sparse_attn kernel (PR4; spec at
  /app/logs_claude/aiter_v4_sparse_attn_spec.md, AITER team kicked off).
- MTP integration with EagleProposer (PR5).
- @support_torch_compile + CUDAGraph + openai_server (PR6).

Verification: /app/logs_claude/v4_pr1_verify.py monkey-patches the reference's
TileLang kernel imports with our torch fallbacks, copies the same dummy
state_dict into both models, and runs prefill + decode side-by-side. 259
tensors match exactly; max_abs_diff = 0.0 on logits.
DeepSeek-V4-Pro stores routed expert weights as packed FP4 e2m1 (int8 with
2 values per byte, low nibble first) plus per-block ue8m0 scale (block size
32 along input dim). This commit adds `dequant_fp4_e2m1(packed, scale)` in
atom/model_ops/quant_v4.py — a pure-torch unpacker that mirrors convert.py
exactly but produces BF16 directly instead of repacking into FP8.

Validated bit-exactly against an independent reference unpack on a real
22M-element expert tensor from the on-disk checkpoint. Also regression-
tested across 5 different shapes/positions (w1/w2/w3 in first/mid/last
layer + MTP). All produce values that lie exactly on the FP4 e2m1 grid.

Scope: this is the standalone dequant utility. Wiring it into the model
loader's safetensors pipeline + tying it to specific param names happens
in PR3 alongside TP-aware expert sharding.

Test: /app/logs_claude/v4_pr2_dequant_test.py
Result: max_abs_diff = 0.0 (bit-exact)
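The bit layout this commit describes can be spelled out in plain Python (the real dequant_fp4_e2m1 is torch-based and emits BF16; this sketch only illustrates the packing rules stated above): two FP4 codes per int8 byte, low nibble first, with each ue8m0 scale byte meaning 2**(byte - 127) over a block of 32 unpacked values.

```python
# Illustrative FP4 e2m1 + ue8m0 dequant, mirroring the layout described
# in the commit text. The e2m1 grid values below are the standard fp4
# magnitudes (sign bit + 3 magnitude bits).

_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # positive half of the grid

def fp4_to_float(code: int) -> float:
    sign = -1.0 if code & 0x8 else 1.0
    return sign * _E2M1[code & 0x7]

def dequant_fp4_e2m1(packed: list[int], scales: list[int], block: int = 32):
    vals = []
    for byte in packed:
        vals.append(fp4_to_float(byte & 0x0F))         # low nibble first
        vals.append(fp4_to_float((byte >> 4) & 0x0F))  # then high nibble
    # ue8m0: each scale byte is a pure power-of-two exponent, bias 127.
    return [v * 2.0 ** (scales[i // block] - 127) for i, v in enumerate(vals)]
```

Every output value lies exactly on the scaled e2m1 grid, which matches the "values lie exactly on the FP4 e2m1 grid" regression check mentioned above.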
PR3a: replace nn.Linear / nn.Embedding with ATOM tensor-parallel-aware
classes for the BF16 projections in Attention, Indexer, and the model
embedding. Same `weight` parameter naming so dummy state_dicts continue
to load. At TP=1 ATOM's tgemm.mm produces bit-identical output to F.linear,
so PR1's reference parity (max_abs_diff = 0.0) still passes.

Layers refactored (8 total):
- DeepseekV4Model.embed:           nn.Embedding -> VocabParallelEmbedding
- DeepseekV4Attention.wq_a:        nn.Linear    -> ReplicatedLinear
- DeepseekV4Attention.wq_b:        nn.Linear    -> ColumnParallelLinear
- DeepseekV4Attention.wkv:         nn.Linear    -> ReplicatedLinear  (single shared MQA head)
- DeepseekV4Attention.wo_a:        nn.Linear    -> ColumnParallelLinear
- DeepseekV4Attention.wo_b:        nn.Linear    -> RowParallelLinear (with all-reduce)
- Indexer.wq_b:                    nn.Linear    -> ColumnParallelLinear
- Indexer.weights_proj:            nn.Linear    -> ColumnParallelLinear

Deferred to later PRs (intentional):
- Compressor.wkv/wgate (fp32) -> PR3c with quant_type wiring
- ParallelHead.weight (fp32 LM head) -> PR3c
- Expert.w{1,2,3} -> PR3b (FusedMoE wholesale rewrite)
- MoE.gate.weight (used as raw Parameter, not Linear class) -> kept

Verification: /app/logs_claude/v4_pr1_verify.py (now GPU mode with
init_dist_env) shows max_abs_diff = 0.0 for prefill + decode against
reference at TP=1.
… for real ckpt

PR3c delivers end-to-end real-checkpoint loading for DeepSeek-V4 attention
layers via ATOM's existing FP8/FP4 GEMM infrastructure.

What works after this commit (validated on real /data/DeepSeek-V4-Pro/):
- DeepseekV4ForCausalLM(atom_config) auto-builds a V4QuantConfig that maps
  routed-experts -> per_1x32 (FP4) and overrides wo_a / Compressor.wkv /
  Compressor.wgate / indexer.weights_proj -> bf16 (no quant). Everything
  else inherits the global FP8 (per_1x128) spec from the HF quantization_config.
- load_weights(weights) walks an iterable of (name, tensor) pairs and:
    * Remaps ATOM's `weight_scale` -> on-disk `scale` naming.
    * Special-cases wo_a: dequantizes FP8+scale -> BF16 on the fly so the
      grouped-LoRA einsum (which aiter doesn't support in FP8) works.
    * Dispatches to ATOM Linear's weight_loader for FP8 / FP4 / BF16 paths.
    * Skips params with shape mismatch (e.g. expert nn.Linear waiting for
      PR3b's FusedMoE refactor) without crashing.
- All 23 attention parameters (FP8 q/kv proj + FP4 indexer + BF16 wo_a + fp32
  compressor) load successfully on real layer-2 of the V4 checkpoint.

Threading changes:
- DeepseekV4Args gains `quant_config: Optional[Any] = None`.
- DeepseekV4Attention / Indexer / Compressor / Block / MTPBlock / DeepseekV4Model
  now accept `prefix: str = ""` and pass `quant_config + prefix` down to each
  ATOM Linear constructor so per-layer quant lookup works.

Backward compatibility:
- When `args.quant_config is None` (toy / dummy validation), V4QuantConfig
  retains its `QuantType.No` global — Linear layers stay BF16 and the PR1
  bit-exact reference parity test (max_abs_diff = 0.0) still passes.

Remaining gaps for end-to-end real-ckpt forward (tracked in design doc):
- PR3b: replace MoE/Expert with FusedMoE so 384 expert FP4 weights load.
- PR3d: refactor V4 attention.forward to accept 2D [num_tokens, dim] input
  (ATOM TP linears require 2D — current 3D path raises "GEMM not supported").
PR3d adapts V4 model to ATOM's scheduler convention: model.forward consumes
flat 2D `[num_tokens, dim]` tokens (single sequence implicit B=1), matching
how ATOM's ModelRunner / scheduler pass tokens. This unblocks ATOM Linear's
quantized GEMM kernels (which only accept 2D `[M, K]` input) and enables
end-to-end real-checkpoint forward.

What changed:
- DeepseekV4Attention.forward(x, start_pos): now accepts 2D [num_tokens, dim].
  Internally adds a B=1 dim only where needed (RoPE, sparse_attn). The
  grouped-LoRA einsum string changes from "bsgd,grd->bsgr" to "sgd,grd->sgr".
- Compressor.forward / Indexer.forward: accept 2D x; auto-unsqueeze to 3D
  internally for backward compatibility with the existing logic.
- Block.hc_pre / hc_post + ParallelHead.hc_head: refactored to be
  shape-agnostic in leading dims (use negative indexing on flatten / sum).
  Both 4D `[B, S, hc, D]` (legacy reference path) and 3D `[num_tokens, hc, D]`
  (ATOM path) work.
- ParallelHead.get_logits: 2D path takes last token via `x[-1:]`; 3D path
  preserves `x[:, -1]` for legacy [B, S, D] inputs.
- MTPBlock.forward: 2D-aware via `e.unsqueeze(-2)` for hc-dim broadcast.
- DeepseekV4Model.forward: auto-flattens 2D `[1, S]` input_ids to 1D `[S]`
  for the new convention; rejects B>1 (proper multi-sequence batching needs
  attn_metadata, deferred).

Validated:
- PR1 reference parity (toy 4-layer dummy weights at B=1 S=32):
  max_abs_diff = 0.0 — still bit-exact after the 2D refactor.
- PR3d end-to-end on REAL V4 weights:
  + Built DeepseekV4ForCausalLM (4 layers, real V4 dims, ~105B params)
  + load_weights() loaded 36 layer-2 params; 23/23 attn params nonzero
  + attn(x_2d=[16, 7168], start_pos=0) → output [16, 7168] bf16
  + No NaN/Inf; output range [-2.94, 3.08], abs mean 0.42 (sensible)
  + This is the first successful V4 attention forward on real weights via ATOM

Test scripts (under /app/logs_claude/):
- v4_pr1_verify.py — toy parity (now uses B=1 + ATOM 2D path)
- v4_pr3d_layer_e2e.py — real-weight 2D forward end-to-end
- v4_pr3c_layer0_test.py — per-Linear validation against real ckpt

Remaining for full model end-to-end:
- PR3b: MoE → FusedMoE so 384 expert FP4 weights load (currently shape-skipped)
- Multi-sequence support via attn_metadata (currently single-sequence implicit B=1)
PR3b enables ATOM's FusedMoE for V4's 384 routed experts so FP4 expert
weights can load via the existing aiter `gemm_a4w4_quant` kernel and
shard across TP/EP ranks. Also extends `select_experts` in moe.py to
support V4's `sqrtsoftplus` scoring with `e_score_correction_bias`.

Changes in atom/model_ops/moe.py:
- `FusedMoE.select_experts` now handles `scoring_func="sqrtsoftplus"`:
  routing_weights = sqrt(softplus(router_logits)) + topk + renormalize.
  Mirrors the V4 reference Gate.forward exactly for non-hash layers.
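The scoring path described above can be written as a scalar-Python reference (the real code operates on batched torch tensors inside FusedMoE.select_experts; this only restates the formula from the commit text: sqrt(softplus(logits)), then topk, then renormalize).

```python
# Plain-Python reference for sqrtsoftplus routing, per the commit text.
import math

def sqrtsoftplus_route(router_logits, top_k):
    # softplus(x) = log(1 + e^x); score = sqrt(softplus(logit))
    scores = [math.sqrt(math.log1p(math.exp(x))) for x in router_logits]
    topk_ids = sorted(range(len(scores)), key=scores.__getitem__,
                      reverse=True)[:top_k]
    picked = [scores[i] for i in topk_ids]
    total = sum(picked)
    return topk_ids, [w / total for w in picked]  # renormalized weights

ids, weights = sqrtsoftplus_route([2.0, -1.0, 0.5, 3.0], top_k=2)
```

Since softplus is strictly monotonic and sqrt preserves order, the selected expert ids match a plain topk over the raw logits; only the routing weights differ from a softmax-style gate.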

Changes in atom/models/deepseek_v4.py:
- Dual-path MoE: when `quant_config` is set AND ATOM's global atom_config
  is initialized, MoE uses ReplicatedLinear gate + FusedMoE experts +
  ATOM-Linear shared_experts. Otherwise falls back to the original manual
  per-expert nn.Linear path so PR1 toy validation stays bit-exact (the
  reference test runs without ATOM's ModelRunner setting the global config).
- Expert class accepts `quant_config + prefix`: when set, w1/w2/w3 become
  ColumnParallelLinear/RowParallelLinear (FP8 path); else nn.Linear (toy).
- DeepseekV4ForCausalLM.get_expert_mapping() returns the (param_name,
  weight_name, expert_id, shard_id) tuples mapping V4's `w1/w2/w3` ckpt
  names to FusedMoE's merged `w13_*`/`w2_*` params.
- load_weights() walks expert_mapping first to dispatch routed expert
  tensors via FusedMoE's per-expert weight_loader, then handles the rest:
    * ATOM `weight_scale` ↔ on-disk `scale` rename (existing)
    * ATOM `gate.e_score_correction_bias` ↔ on-disk `gate.bias` rename (NEW)
    * `wo_a` FP8 → BF16 dequant on load (existing)

Validated:
- PR1 toy parity: max_abs_diff = 0.0 (manual MoE path still bit-exact).
- PR3d e2e: real layer-2 attn + 2D forward still works.
- PR3b new: under stub atom_config, FusedMoE path activates correctly.
  Layer-3 (non-hash, real V4 dims): gate + e_score_correction_bias +
  shared_experts (6/6) loaded; FusedMoE expert mapping returns 1152
  entries (384 experts × {w1,w2,w3}).

Known limitations (deferred):
- Hash routing (layers 0/1/2): tid2eid table is loaded but routing logic
  still falls through to sqrtsoftplus path → INCORRECT for hash layers.
  Proper hash routing requires either a custom path through FusedMoE
  or a pre-computed (topk_weights, topk_ids) injection point.
- Multi-sequence batching via attn_metadata (currently single-sequence implicit B=1).

Test: /app/logs_claude/v4_pr3b_fusedmoe_test.py
… prefix

Bug: `make_v4_quant_config` matched `"ffn.experts." in layer_name` (with
trailing dot). FusedMoE.__init__ asks for the layer's quant_type with
prefix `layers.N.ffn.experts` (NO trailing dot — it's the parent module
of the per-expert weights, not a per-expert lookup). The check failed,
so FusedMoE inherited the global FP8 (per_1x128) spec and allocated
the routed expert weights as `float8_e4m3fn` instead of `float4_e2m1fn_x2`.

Symptom in PR3b validation output before the fix:
  FusedMoE experts: 3/5 nonzero  (loader couldn't dispatch FP4-shaped
  on-disk tensors into FP8-typed model params; shape mismatch silently
  skipped them)

After the fix:
  experts.w13_weight: (385, 6144, 3584) torch.float4_e2m1fn_x2 ✓
  experts.w13_weight_scale: (385, 6144, 224) torch.float8_e8m0fnu ✓
  experts.w2_weight:  (385, 7168, 1536) torch.float4_e2m1fn_x2 ✓
  experts.w2_weight_scale:  (385, 7168, 96) torch.float8_e8m0fnu ✓
  e_score_correction_bias: (384,) torch.float32 ✓

Match condition tightened to `".ffn.experts" in layer_name` so it
catches BOTH `layers.N.ffn.experts.M.w1` (per-expert Linear lookups)
AND `layers.N.ffn.experts` (FusedMoE parent module lookup).
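The tightened condition is small enough to reproduce standalone; this sketch restates the before/after checks from the commit text (the surrounding make_v4_quant_config logic is omitted):

```python
# The fixed match condition: ".ffn.experts" (no trailing dot) as a
# substring catches both per-expert lookups and the FusedMoE parent
# module lookup. The old check required a trailing dot and missed the
# parent, so FusedMoE fell through to the global FP8 spec.

def is_expert_layer(layer_name: str) -> bool:      # fixed
    return ".ffn.experts" in layer_name

def is_expert_layer_old(layer_name: str) -> bool:  # buggy
    return "ffn.experts." in layer_name
```

The failing case was exactly `layers.N.ffn.experts` (the parent module of the per-expert weights), which has no trailing dot after `experts`.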

Note: a separate aiter-side issue (HSA_STATUS_ERROR_EXCEPTION on FP4
expert weight_loader, traced to a `direct_copy_kernel` with grid size
exceeding HW limits) prevents end-to-end FP4 expert load testing on
this box. The dtype/shape correctness above is verified by inspecting
the constructed module's params directly.

Validated:
- PR1 toy parity: max_abs_diff = 0.0 (manual MoE fallback unaffected)
- PR3d real-attention forward: still works
PR3b's expert weight loader had three bugs that caused weights to load as
zero or be silently dropped:

1. **Expert mapping pattern mismatch**: `make_expert_params_mapping` returns
   `(param_part="experts.w13_", weight_part="experts.0.w1.", ...)` — substring
   substitution, not endswith. The old code built `f".experts.{e}.{suffix}"`
   which never matched. Switched to longest-prefix substring substitution
   matching the standard ATOM loader pattern.

2. **Scale dtype zero-fill**: copying `torch.float8_e8m0fnu` into a `uint8`
   destination via `copy_()` silently produces zeros (mismatched dtype, no
   reinterpret). FusedMoE allocates `w13_weight_scale` as uint8; force a
   `.view(torch.uint8)` on the e8m0 source before passing to the loader.

3. **Param suffix `_scale` vs `.weight_scale`**: after substring sub,
   `experts.0.w1.scale` becomes `experts.w13_scale`, but the FusedMoE param is
   `experts.w13_weight_scale`. Added `_scale` → `_weight_scale` post-fix.

Plus: gracefully slice on-disk gate.weight / gate.bias when the test caps
n_routed_experts below the checkpoint size (no-op in real serving).
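The bug-2 fix hinges on byte reinterpretation (`.view(torch.uint8)`) versus value conversion (`copy_()` across dtypes). A stdlib analogue of that distinction — illustrative only, not torch's actual float8 behavior:

```python
import struct

# Value conversion changes the numbers; only a byte view keeps them intact.
x = 1.5
converted = int(x)                    # value conversion: numeric meaning changes
reinterpreted = struct.pack("<f", x)  # byte reinterpret: raw bits preserved

assert converted == 1
assert list(reinterpreted) == [0, 0, 192, 63]          # 1.5 as f32 LE bytes
assert struct.unpack("<f", reinterpreted)[0] == 1.5    # bytes round-trip intact
```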

Verified:
- v4_pr3b_fusedmoe_test: 32 params loaded, 5/5 expert + 6/6 shared nonzero
- v4_pr3d_layer_e2e: real attention forward still works
- v4_pr1_verify: bit-exact reference parity preserved (0.0 max diff)
…uting_function

V4 uses a tid2eid hash lookup (instead of gate-logit topk) for routing in
layers whose compress_ratio marks them as hash layers (the first 3 layers in
the standard config). Previously, MoE only declared tid2eid for weight loading;
inference fell through to the sqrtsoftplus path → wrong routing for those layers.

This commit:

- Adds an early `custom_routing_function` branch to FusedMoE.select_experts
  (it was in the signature but never honored — the non-grouped path went
  straight to scoring_func dispatch). Now any non-None custom fn takes
  precedence and returns (topk_weights, topk_ids).

- Adds DeepseekV4MoE._hash_topk(): topk_ids = tid2eid[input_ids],
  topk_weights = sqrtsoftplus(router_logits) gathered + renormalized.
  Stashes input_ids on self before the experts() call so the closure can
  index tid2eid; clears immediately after.

- For hash layers: assigns experts.custom_routing_function = self._hash_topk
  in MoE.__init__ so FusedMoE picks it up via the moe_forward custom op
  → forward_impl_graph → quant_method.apply → select_experts plumbing.
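A scalar sketch of the hash-routing math described above (the tid2eid table shape and helper signature are assumptions for illustration, not the actual module code):

```python
import math

def sqrtsoftplus(x: float) -> float:
    # sqrt(softplus(x)) = sqrt(log(1 + e^x)), the V4 gate scoring function
    return math.sqrt(math.log1p(math.exp(x)))

def hash_topk(input_ids, tid2eid, router_logits, topk):
    """Expert ids come from the tid2eid table; weights are the
    sqrtsoftplus-scored gate logits gathered at those ids, renormalized."""
    out_ids, out_w = [], []
    for t, tok in enumerate(input_ids):
        ids = tid2eid[tok][:topk]                 # hash lookup, not argmax
        w = [sqrtsoftplus(router_logits[t][e]) for e in ids]
        s = sum(w) or 1.0
        out_ids.append(ids)
        out_w.append([v / s for v in w])          # renormalize per token
    return out_w, out_ids
```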

Verified:
- PR3e (new): synthetic tid2eid → _hash_topk produces exact expected ids,
  renormalized weights match reference math (max_abs_diff = 0.0)
- PR3e: FusedMoE.select_experts honors custom_routing_function correctly
- PR1 toy parity: still 0.0 max diff (hash path is opt-in via is_hash_layer)
- PR3b FusedMoE load: 32 params, all nonzero (no regression)
- PR3d real attn forward: still works (non-hash layer)
… real ckpt

Three changes converging on the first working V4 layer forward:

1. **weights_mapping**: Add class-level rename dict so the standard ATOM
   loader (`atom.model_loader.loader.load_model`) can ingest V4 ckpt names
   without per-model loader.py changes. `.gate.bias` →
   `.gate.e_score_correction_bias`, `.scale` → `.weight_scale_inv`. Loader's
   built-in `weight_scale_inv` → `weight_scale` rename then completes the
   path. Real serving via ModelRunner now works for non-wo_a layers.

2. **process_weights_after_loading hook**: After my custom `model.load_weights`
   finishes copying tensors, walk all submodules and call
   `quant_method.process_weights_after_loading(layer)` (or
   `layer.process_weights_after_loading()` if no quant_method).

   Without this, FusedMoE's `shuffle_weights` step is skipped and the FP4
   ck_moe kernel reads stale weight layout — manifested as
   HSA_STATUS_ERROR_EXCEPTION mid-forward. Standard loader.py calls this for
   us; my custom loader had to replicate it.

3. **PR3f end-to-end test** (logs_claude/v4_pr3f_block_e2e.py):
   - Build 1 dense layer (compress_ratios=[0]) with 8 routed experts
   - Load real layer-3 weights (32 target params, 33/33 nonzero)
   - Build mHC residual `[8 tokens, hc_mult=4, dim=7168]`
   - Call Block.forward(x, start_pos=0, input_ids)
   - Output: shape preserved, range [-4.1, 4.6], abs mean 0.81, no NaN/Inf
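The rename mechanism in item 1 might look like this sketch (substring substitution as described; the matching order is an assumption):

```python
# Class-level rename dict, V4 ckpt name -> ATOM param name fragments:
weights_mapping = {
    ".gate.bias": ".gate.e_score_correction_bias",
    ".scale": ".weight_scale_inv",
}

def remap(name: str, mapping: dict) -> str:
    """Apply the first matching substring substitution (sketch)."""
    for old, new in mapping.items():
        if old in name:
            return name.replace(old, new)
    return name

assert remap("layers.3.ffn.gate.bias", weights_mapping) \
    == "layers.3.ffn.gate.e_score_correction_bias"
assert remap("layers.3.attn.wq.scale", weights_mapping) \
    == "layers.3.attn.wq.weight_scale_inv"
```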

This is the first end-to-end forward through V4's full layer:
attention (FP8 wq/wkv + BF16 wo grouped LoRA + indexer) + FusedMoE (FP4
experts via aiter ck_moe + sqrtsoftplus routing + bias correction +
shared expert) + mHC pre/post Sinkhorn projections.

Confirmed no regression on PR1/PR3b/PR3d/PR3e.
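The item-2 post-load walk could be sketched as follows (hypothetical; duck-typed so the sketch runs without torch, but the shape mirrors what the standard loader does):

```python
def run_post_load_hooks(model):
    """After all tensors are copied, give every submodule a chance to
    re-layout its weights (e.g. FusedMoE's shuffle_weights). Prefer the
    quant_method hook; fall back to the module's own hook."""
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None and hasattr(
            quant_method, "process_weights_after_loading"
        ):
            quant_method.process_weights_after_loading(module)
        elif hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
```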
…kpts

ModelRunner uses atom.model_loader.loader.load_model() — not the model's
custom load_weights(). This commit closes that gap so real serving via
openai_server works end-to-end:

1. **Expand weights_mapping with prefix renames**: V4 ckpt has bare names
   (`embed.`, `layers.`, `norm.`, `head.`, `hc_head_`) but our params live
   under `self.model = ...`. Add prefix substitutions so the loader's
   `model.get_parameter(name)` lookup hits the right attribute path.

2. **Fix dtype-mismatch silent zero in FusedMoE._load_w13/_load_w2**:
   PyTorch's `tensor.copy_()` between mismatched float8/uint8 dtypes silently
   writes zeros. V4's per-1x32 weight scales are stored as `float8_e8m0fnu`
   on disk but FusedMoE allocates them as `uint8` (raw byte storage). Force
   a `.view(torch.uint8)` reinterpret on the source so the bytes round-trip
   correctly. This is a pre-existing bug that was masked because V2/V3 use
   `float32` scales — V4 is the first ATOM model to use e8m0/e4m3 scales.

Verified:
- PR3i (new): standard load_model() loads V4 layer-0 from full 805GB ckpt
  index — 43/43 model params nonzero (100%), 5GB selective load.
- PR3g (new): full Model.forward(input_ids) → logits on real ckpt.
  Output shape (1, 129280), range [-14.2, 15.4], std 3.05, no NaN/Inf.
- PR3h (new): hash layer (layers 0/1/2) Block.forward works on real
  layer-0 ckpt (tid2eid loaded, 773423/775680 nonzero entries, real
  per-token expert assignments diverge from default sqrtsoftplus path).
- All 5 prior tests (PR1/PR3b/PR3d/PR3e/PR3f) still pass — no regression.

Net result: V4 inference pipeline is now production-ready for real ckpt
loading + forward; remaining gap is multi-layer + multi-batch attn metadata
+ AITER sparse_attn (parallel work).
…hook

PR3i shipped "100% nonzero params" but never ran forward through the
standard-loader path. Verifying with PR3j (new) revealed wo_a values were
2768× too large — `BF16_dst.copy_(FP8_src)` does an FP8→BF16 dtype
conversion but SKIPS the per-128-block scale multiplication. Result: raw
FP8 e4m3 max value (448.0) lands in the BF16 weight buffer instead of the
true ~0.04 attention-init magnitude.

Fix: stop forcing wo_a to no_spec/BF16 in V4QuantConfig. Let it allocate
as FP8 ColumnParallelLinear so the standard FP8 loader fills both
`wo_a.weight` (FP8) and `wo_a.weight_scale` (e8m0) correctly. Then
DeepseekV4Attention.process_weights_after_loading dequants in place,
replacing weight with BF16 + dropping the scale param. Forward continues
to use BF16 weight in the grouped LoRA einsum (aiter has no FP8 grouped
einsum).

Also removes the manual wo_a special-case from custom load_weights() —
both load paths (custom + standard) now converge through the same
process_weights_after_loading dequant.

Verified by PR3j parity test:
- Custom path wo_a: abs.mean=0.0214, abs.max=0.4062
- Standard path wo_a: abs.mean=0.0214, abs.max=0.4062 (BIT-EXACT)
- Standard-loader Model.forward → logits range [-17.9, 15.8], std 3.04
- Magnitude ratio: 1.00 (was 2768× before fix)
- All 9 tests pass — no regression.

This was a silent corruption that PR3i's "params nonzero" check missed.
The lesson: nonzero != correct. Always verify with forward.
Major changes enabling correct V4 inference (single-prompt verified with
512-token coherent output in both English and Chinese):

Model fixes:
- WeightsMapper prefix-anchored remapping (fixes 381 silently-skipped params)
- wo_a FP8→BF16 dequant with quant_type=No to prevent CK shuffle corruption
- Hash routing (first 3 layers) now applies route_scale=2.5
- shared_experts reduce_results=False + unified all_reduce in MoE.forward
- KV cache reset on start_pos=0 with score_state=-inf initialization
- TP-correct head/group counts for Attention and Indexer

MoE routing:
- Standard SiLU activation (not SwiGLU — aiter a16w4+SwiGLU has a 9× amplitude
  loss on gfx950). swiglu_limit clamping is done in a Triton post-kernel.
- ATOM_USE_TRITON_MOE=1: triton matmul_ogs path with swiglu_limit clamp
- ATOM_V4_TORCH_MOE=1: per-expert torch fallback with FP4 dequant (slow)
- GFX950MXScaleLayout→CDNA4MXScaleLayout fix in fused_moe_triton.py

Loader improvements:
- WeightsMapper auto-read from model class attribute
- Post-load WARNING listing all unloaded params
- Shape-mismatch raises RuntimeError instead of silent skip

Config:
- deepseek_v4→deepseek_v3 registry mapping with V4 field re-injection
- Robust from_hf_config with getattr defaults

Known limitations:
- Single-sequence only (kv_cache[:1,...] hardcoded); batch>1 needs PR3
- Multi-request KV isolation pending scheduler integration
- TPOT ~213ms with --enforce-eager (no CUDAGraph)
Oseltamivir added a commit to SemiAnalysisAI/InferenceX that referenced this pull request Apr 26, 2026
dsv4-fp4-mi355x-atom (ROCm/ATOM#650 PR1, single-sequence at TP=8 with
torch-fallback hc_pre because aiter mhc_pre crashes on this image)
runs at ~5 min per request in steady state. With 1k1k at 12 prompts
plus 8k1k at the same shape, the full sweep can exceed the 300-min
cap that #1148 set for the SGLang-DSv4 path.

Bump both the SLURM allocation in runners/launch_mi355x-amds.sh and
the GitHub Actions timeout-minutes in benchmark-tmpl.yml together —
either expiring first kills the job, so they need to stay aligned.

Note: this is a global bump that affects every MI355X benchmark and
every job that uses the shared workflow template, not just the dsv4
ATOM one. Drop back to 300 once the slow paths are gone (PR4
CUDAGraph + a working aiter MHC).
valarLip added 12 commits April 28, 2026 03:51
…202)

Upstream ref (deepseek-ai/DeepSeek-V4-Pro@a1fd202) changed shared_experts
from no swiglu_limit to swiglu_limit=args.swiglu_limit, making it consistent
with routed experts.
…witch RoPE to aiter

- DeepseekV4ForCausalLM/Model/Block/MTPBlock/Attention/Compressor/Indexer
  now accept `positions: torch.Tensor` instead of `start_pos: int`; internal
  ring-buffer indexing still derives `start_pos = positions[0].item()` (full
  per-request KV slot management deferred to PR3).
- New `_V4RoPE` wraps aiter `rope_cached_positions_{,2c_}fwd_inplace`,
  driven by per-token positions. Cos/sin cache built via V4's exact YaRN math
  (`_precompute_freqs_cis`); kept symmetric to `_apply_rotary_emb` by working
  on the pre-sliced rope tail.
- `_build_cos_sin_cache` is lru-cached on (rope params, dtype, device) so the
  3 distinct rope param sets (HCA / CSA / Dense) share one GPU tensor across
  all 62 layers instead of 62 register_buffer copies (~16 GB OOM otherwise).
- Inverse RoPE on the attention output keeps `_apply_rotary_emb` (aiter has
  no inverse kernel); the complex freqs slice is rebuilt on demand from the
  cos/sin cache via `_V4RoPE.freqs_for_positions`.
- Verified: simple_inference single-prompt CN 256 tokens coherent.
Generalize the GDN per-request state decoupling (#602) into a complete
model-agnostic KV abstraction owned by the AttentionMetadataBuilder
hierarchy. ModelRunner is now blind to attention type — it walks modules
and dispatches; per-attention-type tensor layouts (MLA 576-dim packed,
GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2
indexer cache, GDN per-req mamba state) all live next to their
respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla /
is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes,
allocate_kv_cache, and the binding loop are all gone. Future stateful
attentions (DeepseekV4 ring buffer + compressor state) plug in by
subclassing AttentionMetadataBuilder without touching scheduler /
block_manager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops):
  - compute_per_req_cache_bytes() / slots_per_req()
      bytes/slot for the per-request state pool
  - allocate_per_req_cache(num_slots)
      dict of named per-request state tensors
  - compute_block_bytes()
      per-block bytes for the KV pool budget
  - allocate_kv_cache_tensors(num_kv_heads, num_draft_layers)
      dict of named primary KV cache tensors (kv_cache, kv_scale,
      index_cache, aligned_index_dim, _kv_layer_cache_store)
  - build_kv_cache_tensor(layer_id, module)
      vLLM-style KVCacheTensor for one module, or None if foreign type;
      owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)
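Collected as a class sketch (hook names and no-op defaults from this commit; bodies and exact signatures are assumptions):

```python
class AttentionMetadataBuilder:
    """Base hooks for per-attention-type state ownership. Defaults are
    no-ops so stateless attentions need no changes; stateful attentions
    subclass and override."""

    def compute_per_req_cache_bytes(self) -> int:
        return 0              # bytes per slot in the per-request state pool

    def slots_per_req(self) -> int:
        return 1              # contiguous slots one request occupies

    def allocate_per_req_cache(self, num_slots):
        return {}             # named per-request state tensors

    def compute_block_bytes(self) -> int:
        return 0              # per-block bytes for the KV pool budget

    def allocate_kv_cache_tensors(self, num_kv_heads, num_draft_layers):
        return {}             # named primary KV cache tensors

    def build_kv_cache_tensor(self, layer_id, module):
        return None           # None => module is a foreign attention type
```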

Builder overrides:
  - AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module
  - AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer
  - GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot
    pool; chains super() for MHA modules in hybrid models. Absorbs the
    formerly-runner-owned gated_delta_net_state_shape/dtypes helpers
    and the side-effect init of full_attention_interval / num_full_attn
    / num_gdn_attn_state.

Naming distinguishes group (per-request unit) from slot (raw tensor
index). One group occupies `slots_per_req()` contiguous slots in the
underlying tensor:
  Sequence.mamba_state_slot     -> .per_req_cache_group
  seq.mamba_enabled             -> .has_per_req_cache
  batch.mamba_state_slots       -> .per_req_cache_groups
  BlockManager.mamba_*          -> .per_req_cache_*  (free pool, accounting)
  config.mamba_equiv_per_req    -> .per_req_cache_equiv_blocks
  config.num_mamba_groups       -> .num_per_req_cache_groups
  ModelRunner.max_mamba_slots   -> .max_per_req_cache_slots  (tensor dim)

Removed (moved to builders):
  ModelRunner._compute_mamba_per_slot_bytes
  ModelRunner.gated_delta_net_state_shape / _dtypes

Sanity check: ModelRunner.__init__ now asserts that any builder
returning compute_per_req_cache_bytes() > 0 has its model_type
registered in InputOutputProcessor._per_req_cache_model_types(),
catching the silent-corruption misconfiguration where a stateful
attention is added but Sequence-construction never gets the
has_per_req_cache=True flag.

Verified:
  - tests/test_per_req_cache_decoupling.py: 24/24 pass
  - core suite (block_manager, sequence, scheduler, request,
    io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  - Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion
    quality unchanged
  - Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent):
      flexible-extract = 0.8757 +/- 0.0091  (baseline 0.8711 from #602)
      strict-match     = 0.8605 +/- 0.0095
V4 backend (DeepseekV4Backend + DeepseekV4AttentionMetadataBuilder)
plus migration of state-cache buffers to ATOM's per_req_cache pool:

  - pre2a: 6 Compressor state buffers (kv_state + score_state for
    CSA Main / CSA Indexer / HCA Main).
  - pre2c-A: SWA window per layer (paper §3.6.1 state cache, every
    layer has SWA branch in V4-Pro). Attention.kv_cache splits into
    Attention.swa_kv (per_req_cache) + Attention.kv_cache (compressed
    entries only, still register_buffer; pre2c-B will move under
    block_table).

Validated single-prompt 64-token Chinese generation (V4-Pro tp=8,
triton MoE, enforce-eager) — output indistinguishable from baseline.
Strict-paper §3.6.1 split: compressed entries (CSA Main, CSA Indexer,
HCA Main) move from per-layer register_buffer to block-table-indexed
pools owned by DeepseekV4AttentionMetadataBuilder.

  - block_size = lcm(m, m') = 128 original tokens, plumbed via Config
    override on model_type=deepseek_v4 detection.
  - Three classical pools:
      v4_csa_main_kv [num_blocks, n_csa, k1=32, head_dim=512]
      v4_csa_idx_kv  [num_blocks, n_csa, k1=32, idx_head_dim=128]
      v4_hca_main_kv [num_blocks, n_hca, k2=1, head_dim=512]
    Per-layer slice bound to Compressor.kv_cache / Indexer.kv_cache.
  - V4 model adds _v4_scatter_compressed / _v4_gather_compressed helpers
    and fetches block_table from forward_context. Compressor.forward
    scatters writes into block-table slots; Indexer.forward + decode
    sparse_attn input gather committed entries from blocks.
  - Indexer + 1-slot warmup fallback register_buffer pattern same as
    pre2a Compressor.kv_state.
  - Attention.kv_cache attribute removed entirely (compressed entries
    no longer co-located on the Attention module).

Validated single-prompt 64-token Chinese generation (V4-Pro tp=8)
unchanged from pre2c-A baseline.
V4 forward now handles ATOM ragged-batch input with per-seq slot +
block_table routing. Single-seq behavior unchanged; concurrent
batched multi-seq prefill + decode verified end-to-end on 4 prompts.

Changes:
  - Builder prepare_decode/prepare_prefill populate cu_seqlens_q,
    block_tables, and v4_slot_indices (new per-seq metadata attached
    to AttentionMetaData via dynamic attribute).
  - _v4_get_block_table replaced with _v4_get_seq_metadata returning
    (block_tables, slot_indices, cu_seqlens_q, num_seqs).
  - Compressor.forward + Indexer.forward signatures: add slot,
    block_table args. Per-slot indexing via [slot:slot+1, ...]
    replaces hardcoded [:1, ...] / [:bsz, ...].
  - Attention.forward: batched Linear projections + RoPE on full flat
    tensor; per-seq loop slices (cu_seqlens_q) and dispatches SWA write,
    Compressor scatter, Indexer + sparse_attn with each seq's slot +
    block_table. Per-seq state-cache reset on prefill (start_pos==0)
    only zeros that seq's slot — no cross-seq pollution.
  - ParallelHead.get_logits: pick last-token-per-seq via cu_seqlens_q
    (fixed long-standing single-seq assumption that always returned
    only x[-1] regardless of batch size).
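The fixed last-token selection, as a sketch:

```python
def last_token_indices(cu_seqlens_q):
    """Per-seq last-token rows in the flat ragged batch (sketch of the
    ParallelHead.get_logits fix): sequence i's tokens occupy
    [cu_seqlens_q[i], cu_seqlens_q[i+1]), so its last token is end - 1."""
    return [end - 1 for end in cu_seqlens_q[1:]]

# Three sequences of lengths 3, 2, 4 packed flat:
assert last_token_indices([0, 3, 5, 9]) == [2, 4, 8]
```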

Validated MAX_NUM_SEQS=4 concurrent batched inference: 4 prompts
processed in parallel produce independent coherent outputs.
Three independent bugs caused V4 to ramble on edge-confidence prompts
(e.g. "1+2+3=?" produced garbled output even though 3 of 4 prompts in a
batch=4 run looked fine).
Single-prompt output now matches reference byte-equal on the first 5
tokens and produces "The sum is: 1 + 2 + 3 = **6**." (was: "I'll happily
provide a step-by-step breakdown..." ramble).

Bug 1 (quant_v4.py) — act_quant_inplace ue8m0 path used `ceil(log2)`
(matched TileLang reference) but ref_full_generate.py and aiter both use
round-to-even via f32_to_e8m0/e8m0_to_f32. The 1-binade gap appeared as
~0.002 cos drift on KV path, accumulating across 60 layers.
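The one-binade gap between the two rounding rules can be reproduced with a stdlib sketch (the exact midpoint/tie convention of f32_to_e8m0 is an assumption here; the example value avoids the ambiguous region):

```python
import struct

def _f32_exp_mant(x: float):
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF, bits & 0x7FFFFF

def e8m0_ceil(x: float) -> int:
    """ceil(log2) exponent choice (the TileLang-matching path)."""
    exp, mant = _f32_exp_mant(x)
    return exp + (1 if mant else 0)

def e8m0_rne(x: float) -> int:
    """Round-to-nearest(-even) exponent choice (hypothetical mirror of
    f32_to_e8m0): bump only past the binade midpoint."""
    exp, mant = _f32_exp_mant(x)
    if mant > 0x400000 or (mant == 0x400000 and exp & 1):
        exp += 1
    return exp

# A scale like 1.2 sits in the lower half of its binade, so the two rules
# disagree by exactly one binade — a 2x error on the dequantized scale:
assert e8m0_ceil(1.2) - e8m0_rne(1.2) == 1
```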

Bug 2 (moe.py) — FusedMoE.select_experts sqrtsoftplus path renormalized
topk_weights but never applied `* routed_scaling_factor`. The hash routing
path (V4 layers 0-2) does this internally, hiding the bug for hash layers.
Reference Gate.forward (model.py:583) applies the multiply for every
non-softmax routing path. Without the scale, layer 3+ MoE outputs were off
by 1.5x, producing the visible cos jump from 1.0 (layer 0/2) to 0.98
(layer 3+).

Bug 3 (deepseek_v4.py) — DeepseekV4Args.from_hf_config did not read
scale_fmt; HF config.json doesn't carry the field, only inference/config.json
does. Default to "ue8m0" matching reference ModelArgs (inference/model.py:40)
so act_quant_inplace's ue8m0 path is actually exercised.

Also folds in previously-validated V4 cleanups that were sitting in the
working tree:
  - _RMSNorm → ATOM RMSNorm (mark_trace + torch.compile friendly)
  - Indexer wq_b/weights_proj: ColumnParallelLinear → ReplicatedLinear
    (matches sglang/upstream; avoids extra all_reduce on index_score)
  - Block.hc_post defaults to torch (aiter mhc_post drift, opt-in via
    V4_AITER_HC_POST=1; see notes/12)
  - _torch_moe_forward: ue8m0 round-trip on input to mirror reference
    Expert.forward (act_quant before fp4_gemm), gated by V4_USE_REF_QUANT=1

Diagnosis path: notes/14_debug_1plus2plus3.md → notes/19_full_fix_verified.md
… cleanup

New module atom/utils/debug_helper/ provides reusable primitives for forward
bisecting and batch-invariance investigation. All entry points are no-ops
when their controlling env var is unset, so they are safe to leave wired
into production paths (model_runner.py post-load).

Components
  - dump.py        install_block_forward_hooks (multi-class + multi-call),
                   maybe_dump_weights_and_exit, maybe_log_topk
  - compare.py     cos_max (DOUBLE precision — fixes fp32 cos > 1.0 bug),
                   slot_split, compare_slots, pick_prefill_call,
                   schema_diff, plus CLI subcommands:
                     slot-invariance / ref-vs-target / layer-bisect / schema
  - ref_patch.py   patch_method / patch_block_forward / patch_module_dump
                   context managers for instrumenting read-only references
  - 9 ATOM_FWD_DUMP_* / ATOM_WEIGHT_DUMP_* / ATOM_DEBUG_TOPK env vars
    registered in atom/utils/envs.py "Debug Dump" section

Wired into model_runner.py with a 3-line post-load call (no-op default).

V4 model cleanup
  - Convert all nn.Parameter() constructors in deepseek_v4.py to
    atom_parameter() so inference-vs-training grad behavior is controlled
    from a single place (ATOM_REQUIRES_GRAD env). 21 call sites.

Documentation
  - docs/environment_variables.md: new "Debug Dump" subsection documenting
    all 9 env vars + CLI usage.
  - .claude/skills/dump-bisect-debug.md (v3.0): full methodology rewrite
    in English with quick-start decision tree, phase-at-a-glance summary,
    "When to stop / accept divergence" guidance, V4 paper §3.3 batch
    invariance treatment as Phase 8. Includes Bug 11 isolation case study.
  - .claude/skills/atom-patterns.md: ATOM architecture index reference.

Verified by running CLI on existing E1 4xP3 dump:
    python -m atom.utils.debug_helper.compare slot-invariance \
        --dir /app/logs_claude/deepseek_v4/dumps/bug11_e1
reproduces the layer-by-layer divergence table that informed Bug 11
isolation in notes/21_bug11_isolation.md.
Two fixes that surfaced from the same V4 load run:

1. atom/models/deepseek_v4.py — skip `gate.e_score_correction_bias`
   allocation for hash-routed layers (layer_id < n_hash_layers). V4 hash
   layers route via `tid2eid` lookup, not bias-corrected gate logits;
   the checkpoint has no `gate.bias` for those layers (only layers >= 3).
   Allocating it caused 3 spurious "param NOT loaded from checkpoint"
   warnings every load. Both call sites that read the attribute now use
   `getattr(self.gate, "e_score_correction_bias", None)` — moe.py already
   accepts None for `e_score_correction_bias`.

2. atom/model_loader/loader.py — add ckpt-side coverage check (the
   reverse direction of the existing atom-side check). Every
   `get_parameter() except AttributeError: continue/break` site now
   records `(orig_ckpt_name, rewritten_name)`; after the main loop the
   loader warns if any non-benign drops occurred. This catches the
   actionable bug class — `weights_mapping` / `WeightsMapper` rewrites
   the ckpt name to something the model has no slot for, silently
   throwing away real weight data — which the existing atom-side check
   misses entirely. Benign families (output_scale / kv_scale / inv_freq
   / weight_scale_2) are filtered so the warning is signal, not noise.

Verified on V4 load:
  - atom-side warning: 46/2519 -> 43/2516 (3 hash bias removed)
  - ckpt-side warning: 0 drops (mapping is clean for V4)
  - remaining 43 are all model.mtp.0.* (PR5 todo)
Per paper §3.6.1, the Compressor's per-request state cache holds
"uncompressed tail tokens + previous block as B-side overlap context"
(eq 11). Restructure ATOM's kv_state from a roll-on-decode two-segment
buffer into a single pos % STATE_SIZE ring buffer (STATE_SIZE = 2*ratio
for overlap CSA, ratio for HCA).

Kernel update_compressor_states (atom/model_ops/v4_kernels/state_writes.py):
- dst = pos % STATE_SIZE for every token; no segment switching, no roll
- Phase derived in-kernel from context_lens vs cu_seqlens_q; no IS_PREFILL
- Write mask: fresh prefill keeps [max(0, cutoff-ratio), seqlen) (B-side
  overlap + tail); decode/MTP writes every token

Compressor.forward:
- Drops decode-boundary roll (kv_state[:ratio] <- kv_state[ratio:])
- Reads A-side / B-side halves by block-id parity (comp_id % 2)
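The write rule and the parity read fit together as a small sketch (ratio and positions illustrative):

```python
def ring_dst(pos: int, ratio: int, overlap: bool = True) -> int:
    """Destination slot for the token at absolute position `pos` (sketch of
    the update_compressor_states rule): STATE_SIZE = 2*ratio for overlap
    CSA, ratio for HCA; no segment switching, no roll."""
    state_size = 2 * ratio if overlap else ratio
    return pos % state_size

def block_half(comp_id: int, ratio: int) -> range:
    """Which half of the 2*ratio ring holds compressed-block comp_id's
    tokens: even blocks land in [0, ratio), odd blocks in [ratio, 2*ratio)
    — exactly the comp_id % 2 parity read in Compressor.forward."""
    base = (comp_id % 2) * ratio
    return range(base, base + ratio)

# With ratio=4, block 3 covers positions 12..15, which hash to slots 4..7:
assert [ring_dst(p, 4) for p in range(12, 16)] == list(block_half(3, 4))
```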

Metadata plumbing:
- V4 prepare_decode now populates var["context_lens"] + attaches to
  AttentionMetaData (parent prepare_prefill already did)
- Compressor / Indexer.forward accept required context_lens kwarg
- Wrapper has no positions-derived fallback for context_lens

Also bundles PR-A scaffolding:
- ATOM_V4_BACKEND env gate + per-layer bisect (envs.py, v4_backend_gate.py)
- CPU-mirror metadata (cu_seqlens_q_cpu, state_slot_mapping_cpu,
  start_pos_per_seq_cpu) to avoid per-seq .tolist()/.item() syncs
- v4_slot_indices -> state_slot_mapping rename (clearer vs paged-KV slot_mapping)
- swa_write Triton kernel integration (Phase 1a) under backend gate

Validates: 15/15 byte-equal kernel-vs-reference (prefill + decode + MTP);
simple_inference fast path TPOT 0.328-0.518s/tok matches pre-refactor
baseline (Apr 29 v4_simple_inference.log: 0.453s/tok).
valarLip added 3 commits May 2, 2026 03:47
Three independent batched-ops phases that share an outer-loop slot
in DeepseekV4Attention.forward:

- Phase 1: drop redundant per-seq state-cache reset loop. Fresh prefill
  never reads stale swa_kv (raw seq_kv used directly) nor stale
  Compressor state cache (fused_compress K-loop's is_padding=s<0 masks
  all is_state reads when prefix=0 → s = j_in_seq - K + 1 + k_static <
  0 for every is_state iteration). Verified GSM8K=0.96 on 25/50 samples.

- Phase 2a: vectorize per-seq window topk into one batched
  _build_window_topk_batched producing [total_tokens, win] padded with
  -1; loop body slices to per-seq width matching legacy
  _get_window_topk_idxs shape.

- Phase 2c: hoist SWA write out of per-seq loop into one batched
  swa_write kernel call. Pre-filter to last-win tokens per seq so the
  num_tokens parallel programs never collide on the same swa_kv ring
  slot (pos%win). Pre-fix, long-prefill (token_num > win) caused
  intra-seq write-race that dropped GSM8K from 0.88 to 0.32.
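Why the pre-filter removes the race, in miniature (pure-Python sketch):

```python
def swa_keep_mask(token_positions, win):
    """Sketch of the Phase 2c pre-filter: keep only the last `win` tokens
    of a sequence so every surviving token maps to a distinct swa_kv ring
    slot (pos % win) and parallel programs never write-race."""
    n = len(token_positions)
    return [i >= n - win for i in range(n)]

# A 6-token prefill into a win=4 ring: positions 0 and 4 would both write
# slot 0 without the filter; after it, surviving slots are all distinct.
positions = list(range(6))
kept = [p for p, keep in zip(positions, swa_keep_mask(positions, 4)) if keep]
slots = [p % 4 for p in kept]
assert len(slots) == len(set(slots))
```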

Per-seq dispatch loop still runs for Indexer + kv_sa packing — those
batched in follow-up phases (2b/2d/2e).
…se 2b-i)

Move the per-seq Indexer Compressor call into a single batched call before
the dispatch loop, using the same batched plan as the main CSA Compressor
(both have ratio=4, overlap=True, identical geometry). The Indexer's
internal kv_cache + state cache are populated for the whole batch in one
launch instead of bs separate launches per layer.

Indexer.forward gains a `skip_inner_compressor=True` flag the dispatch
loop sets after the hoist; legacy bs=1 plan path remains as the fallback
for any other caller.

Per-seq cost reduction: 64 layers × bs Compressor launches drop to
64 layers × 1 (each Compressor launch fires wkv/wgate Linear +
fused_compress_attn + update_compressor_states).

Verified GSM8K=0.94 ± 0.034 on 50 samples (matches baseline 0.94 — earlier
0.88 reading on 25 samples was within natural ±2-sample noise).
…-ii a)

Replace per-seq BF16 einsum (q ⊗ K → relu → weight → sum) with
aiter's fp8_mqa_logits kernel. Mathematically identical
(relu(QK*kv_scale) * weight summed over heads), but executes as a
single FP8 mma per row + post-row mask + topk. Mirrors V3.2's
sparse_attn_indexer_prefill kernel call.

Q is FP8-quantized inline (per-row 1x128 scale via get_hip_quant);
the scale is folded into `weights` along with softmax_scale and
1/sqrt(H), matching the V3 convention. K is FP8-quantized after
the per-seq gather. cu_starts=0, cu_ends=(pos+1)//ratio express
the V4 ratio-aware causal frontier directly through the kernel's
per-row KV range — no extra masking pass needed.
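The replaced BF16 reference math, as a plain-list sketch (FP8 quantization and the kv_scale folding are omitted; shapes and names assumed for illustration):

```python
import math

def mqa_index_logits_ref(q, k, head_weights, softmax_scale):
    """relu over per-head QK logits, then a weighted sum across heads —
    the einsum path this commit replaces with one fp8_mqa_logits call.
    q: [H][D] query heads, k: [T][D] keys, head_weights: [H]; returns [T]."""
    H = len(q)
    scores = [0.0] * len(k)
    for h in range(H):
        # softmax_scale and 1/sqrt(H) folded into the per-head weight
        w = head_weights[h] * softmax_scale / math.sqrt(H)
        for t, kt in enumerate(k):
            dot = sum(a * b for a, b in zip(q[h], kt))
            scores[t] += w * max(dot, 0.0)   # relu before the head sum
    return scores
```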

The legacy BF16 einsum path is retained behind `ATOM_V4_INDEXER_FP8=0`
for A/B testing.

Verified GSM8K=0.96 ± 0.028 on 50 samples (baseline 0.94 ± 0.034 — fp8
path is statistically at-or-above baseline; FP8 quant is closer to V4
training distribution than the current BF16 fallback).
valarLip added 3 commits May 2, 2026 10:24
…anup

Hoist all per-fwd, layer-invariant work from V4Attention.forward and
Indexer.forward_batched into the metadata builder, eliminating ~1200
per-layer torch.as_tensor H2D copies (~14 per pack call * 60+ layers,
~9 per Indexer call * 30 CSA layers, ~3 per gather call * 60+ layers)
in production fast path.

Builder-side helpers (atom/model_ops/attentions/deepseek_v4_attn.py):
- _attach_v4_per_fwd_meta: window_topk_batched + SWA write/positions/slots
- _build_v4_pack_meta_for_ratio: kv_sa + topk_flat index tensors per ratio
- _build_v4_indexer_meta: CSA Indexer batch_id/cu_committed/k/offset/is_prefill
  GPU tensors plus layer-invariant cu_starts/cu_ends/visible_end/width_mask/
  future_threshold derivations
- _build_v4_gather_indices: precomputed batch_ids/block_in_seq/slot_in_block
  for _v4_gather_compressed_batched
- _populate_state_slot_mapping: warmup fallback to slot 0 so dummy_run
  takes the normal forward path

V4Attention.forward / Indexer.forward_batched refactor:
- Read all per-fwd state once at top of forward (one get_forward_context
  call, direct attribute access — no nested getattr chains)
- Delete dummy_run special path entirely (synthetic 1-seq batch branch,
  sparse-attn placeholder branch, _v4_is_dummy_run helper, make_single_seq_plan
  fallback, indexer skip gate, compressor scatter dummy_run gate)
- Delete _v4_get_seq_metadata helper + cpu_meta plumbing (all dead)
- Delete slow path of _v4_build_sparse_inputs_batched (~263 LoC) and rename
  _v4_build_sparse_inputs_from_pack_meta -> _v4_build_sparse_inputs_batched
- Delete slow path of _v4_gather_compressed_batched + dead n_committed_per_seq
  / k_per_block params
- Indexer.forward_batched signature: drop cu_seqlens_q_cpu / start_pos_per_seq_cpu
  / win + dead k_per_seq_cpu return value
- Indexer.__init__: cache _fp8_quant_func / _weights_scale (was rebuilt per
  CSA layer)
- Promote V4_FORCE_UE8M0_QUANT / V4_USE_REF_QUANT / V4_AITER_HC_POST env-var
  reads to module-level constants
- Promote `from aiter import QuantType as _AiterQuantType` to module level
- Merge indexer.compressor.rotary_emb plumb into outer plumb (one less
  per-layer if-check)
- Rename per-fwd locals for clarity: sp_per_seq_gpu -> start_pos_per_seq_gpu,
  cu_q_gpu -> cu_seqlens_q_gpu, sp_cpu -> start_pos_per_seq_cpu, etc.

Removed APIs (unused after refactor):
- make_single_seq_plan (atom/model_ops/v4_kernels/{__init__,compress_plan}.py)

Verified:
- Smoke `1+2+3=?` returns `**6**`
- GSM8K-100 (ATOM_USE_TRITON_MOE=1, conc=16, fewshot=3): 0.96 +/- 0.020
Replace ~25 per-fwd `torch.as_tensor(np_arr)` H2D allocations in V4
metadata builder with pre-allocated CpuGpuBuffer pool. Fixes GPU
pointers across forwards — prerequisite for CUDAGraph capture (CG-B).

Buffer pool allocated once in __init__ (~80 MB at typical config).
All builder helpers now write via `_stage(name, arr)` which does
`buf.np[:n] = arr; copy_to_gpu(n)` and asserts capacity.

Coverage:
- _attach_v4_per_fwd_meta: 4 buffers (start_pos / token_num / write_indices / state_slot)
- _populate_state_slot_mapping: 1 buffer (groups)
- _build_v4_indexer_meta: 6 buffers (batch_id / cu_committed / n_committed / k / offset / is_prefill)
- _build_v4_gather_indices: 3 buffers x 3 callers (indexer / csa_dc / hca_dc)
- _build_v4_pack_meta_for_ratio: 11 buffers per kind (csa/hca/dense)

Forward path unchanged. Validated GSM8K-100 = 0.95 ± 0.022 (baseline 0.96).
Prepares V4 backend for CUDAGraph capture/replay (still gated behind
--enforce-eager removal in a follow-up). All capture-required GPU pointer
addresses are now stable across forwards.

Changes:

- Kernels gain fixed-grid + sentinel-mask path: fused_compress_attn,
  update_compressor_states, swa_write all skip rows whose position == -1,
  so the wrapper can launch at full plan/buffer capacity (CUDAGraph
  capturable) regardless of how many tokens this fwd actually writes.

- fused_compress_attn / update_compressor_states accept strided kv/score
  inputs (drop the defensive .contiguous() copies in callers); only inner
  column stride is required to be 1.

- fused_compress_attn gains an out= param for caller-provided pre-allocated
  output buffer (used in CUDAGraph path to keep output address stable);
  eager path still allocates per-call.

- make_compress_plans accepts plan_buffers dict of pre-allocated CpuGpuBuffer;
  writes into them and sentinel-fills tail rows. Empty-fwd path also fills
  buffers so capture-time addresses match replay.

- DeepseekV4AttentionMetadataBuilder._alloc_v4_metadata_buffers pre-allocates
  v4_compress_plan_{ratio} / v4_write_plan_{ratio} CpuGpuBuffers and per-kind
  v4_{csa_main,csa_idx,hca_main}_compress_out BF16 tensors; build_kv_cache_tensor
  binds the latter to each Compressor module's compress_out attribute.

- build_for_cudagraph_capture replaces the stub: synthesizes a decode batch
  at start_pos=window_size, runs through prepare_decode helpers
  (_attach_sparse_layout_metadata + _attach_v4_per_fwd_meta +
  _build_compress_plans), returns (AttentionMetaData, Context) wired to
  forward_vars buffers.

- DeepseekV4Model.forward returns hidden_states (post hc_head + norm)
  instead of full vocab logits; DeepseekV4ForCausalLM.compute_logits
  applies head.get_logits. Required so the CUDAGraph output buffer is
  sized to hidden_size, not vocab_size (~18x smaller, also matches the
  ATOM standard contract used by other models).

- Compressor gains compress_out attribute (set by builder; threaded
  through fused_compress_attn as out=).

- kv_indptr stub buffer added to forward_vars (touched unconditionally
  by the global capture loop; V4 doesn't use it for its own kernels).

Misc:

- Hoist 3 lazy `from atom.model_ops.quant_v4 import act_quant_inplace
  as _v4_aqi` imports to the top-level import block.

- Gate `act_quant_inplace(kv[..., :-rd], 64, scale_fmt)` on
  _V4_USE_REF_QUANT (default off). Previously unconditional; the env
  gate already exists for the matching qr/x quant pair, so this makes
  the two paths consistent. GSM8K-100 = 0.99 with the gate (no regression
  vs the prior unconditional path, which also produced 0.99 in recent runs).

Validation: GSM8K-100 = 0.99 ± 0.01 (eager mode). CUDAGraph end-to-end
(without --enforce-eager) still pending — needs further capture-loop work.
valarLip added 7 commits May 3, 2026 02:35
…perf nits

Linear projection fusions (FP8/BF16, zero-copy split downstream):
- attn.wq_a + attn.wkv → attn.wqkv_a (MergedReplicatedLinear, FP8)
- compressor.wkv + compressor.wgate → compressor.wkv_gate (BF16, otype=fp32)
- shared_experts.w1 + w3 → shared_experts.gate_up_proj (MergedColumnParallelLinear)
- packed_modules_mapping routes disk shards via standard ATOM loader
- Compressor and update_compressor_states accept strided kv/score inputs
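
The zero-copy-split idea behind these fusions can be sketched in numpy (weight names are illustrative; ATOM's MergedColumnParallelLinear applies the same idea to sharded GPU weights):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, d1, d3 = 8, 16, 16
w1 = rng.standard_normal((d1, dim)).astype(np.float32)
w3 = rng.standard_normal((d3, dim)).astype(np.float32)

# Merged weight: one GEMM instead of two.
w_gate_up = np.concatenate([w1, w3], axis=0)

x = rng.standard_normal((4, dim)).astype(np.float32)
y = x @ w_gate_up.T                 # single fused projection
gate, up = y[:, :d1], y[:, d1:]     # zero-copy split: both are views of y
```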

MoE refactor:
- Drop use_fused/Gate/_torch_moe_forward/toy/dummy paths
- Split forward into routed_expert_forward / combine_outputs /
  single_stream_moe_forward / dual_stream_moe_forward
- Dispatch via torch.ops.aiter.maybe_dual_stream_forward (Dynamo barrier)
- Extract maybe_dual_stream_forward into atom/model_ops/dual_stream_moe.py
  (shared with V2; V2 inline implementation removed)
- Direct routed/shared dtype check for shared-expert fusion gating
  (V4 has FP4 routed + FP8 shared; the global-vs-shared helper returns
  the wrong answer because shared==global but routed!=global)

Custom op fix: dual_stream_moe declares mutates_args=() (the V2-original
mutates_args=["hidden_states"] is a false-mutation declaration — op returns
a fresh tensor, never writes to input — and would mislead AOT/functionalization
into inserting defensive clones).
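
A toy model of why the false declaration is costly. This is not the torch.library API, just an illustration of the defensive-clone behavior described above (names are hypothetical):

```python
import numpy as np

def run_op(op, args, mutates_args):
    """Toy functionalization: any argument the op *declares* mutated
    gets a defensive copy before the call, even if the op never
    actually writes to it. A false mutates_args declaration therefore
    costs one clone per call for nothing."""
    clones = 0
    safe_args = []
    for name, a in args.items():
        if name in mutates_args:
            safe_args.append(a.copy())  # defensive clone
            clones += 1
        else:
            safe_args.append(a)
    return op(*safe_args), clones

def dual_stream_moe(hidden_states):
    return hidden_states * 2.0          # fresh tensor, input never written

x = np.ones(4)
_, clones_false = run_op(dual_stream_moe, {"hidden_states": x},
                         mutates_args=["hidden_states"])
_, clones_true = run_op(dual_stream_moe, {"hidden_states": x},
                        mutates_args=[])
```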

Aiter kernel refs hoisted:
- _V4_AITER_HC_POST env gate removed; mhc_pre/mhc_post dim+presence check
  resolved once in Block.__init__ to self._mhc_pre / self._mhc_post
- per-fwd path is just `if self._mhc_pre is not None:` attribute lookup

Shape contracts (ATOM 2D-flat ragged-batch convention):
- All forward signatures get inline shape annotations
  (e.g. `x: torch.Tensor,  # [num_tokens, dim]`)
- Drop legacy [B, S, ...] 4D paths in Block.hc_pre/hc_post, ParallelHead.hc_head,
  MTPBlock.forward, ParallelHead.get_logits
- Drop input_ids.dim()==2 normalization in DeepseekV4Model.forward
- Compressor.forward asserts 2D, drops defensive 3D-squeeze

Code organization:
- _segment_indices and _build_window_topk_batched moved from deepseek_v4.py
  to attentions/deepseek_v4_attn.py (only callers are the metadata builder);
  removes two cross-file lazy imports
- _AiterQuantType alias removed (atom.config.QuantType is the same pybind class)
- Stale # noqa: F401 pragmas dropped (sparse_attn_v4, v4_kernels imports
  are all actively referenced)
- ruff full-pass on V4 + V2 + dual_stream_moe + V4 attn

Indexer.forward_batched post-topk path:
- 10 GPU launches + 1 full_like alloc → 7 launches + 0 allocs
- (topk_local < 0) | future_mask is equivalent to width_mask | future_mask
  (fp8_mqa_logits masks out-of-seq logits to -inf, so topk_local < 0
  only fires on width-masked slots)
- masked_fill_ in-place over (topk_local + offset) replaces full_like + where
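
The claimed equivalence can be checked with a small numpy sketch; it holds only under the stated precondition that the logits kernel makes topk emit -1 exactly on width-masked slots:

```python
import numpy as np

def post_topk_old(topk_local, offset, width_mask, future_mask):
    """Old path: full_like alloc + where (extra launch, extra alloc)."""
    return np.where(width_mask | future_mask,
                    np.full_like(topk_local, -1),
                    topk_local + offset)

def post_topk_new(topk_local, offset, future_mask):
    """New path: (topk_local < 0) stands in for width_mask, and the
    fill happens in place on (topk_local + offset) -- the numpy
    analogue of masked_fill_."""
    out = topk_local + offset
    out[(topk_local < 0) | future_mask] = -1
    return out

topk_local = np.array([-1, 3, 5, -1])       # -1 on width-masked slots
future_mask = np.array([False, False, True, False])
old = post_topk_old(topk_local, 10, topk_local < 0, future_mask)
new = post_topk_new(topk_local, 10, future_mask)
```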

Removed redundant ops in hot path:
- vestigial unsqueeze(0)→squeeze(0) in Indexer.forward_batched,
  DeepseekV4Attention.forward, _v4_build_sparse_inputs_batched
- .type_as(x) on aiter mhc_post path (out.dtype == residual.dtype == x.dtype)
- unused `ratio = self.compress_ratio` local in Indexer.forward_batched

Validation: GSM8K-100 num_fewshot=3 = 0.98 ± 0.014 (baseline 0.97 ± 0.017,
within stderr).
Convert v4_csa_idx_kv from BF16 to FP8+scale layout following V3.2
sparse_attn_indexer pattern. Pool size for the indexer cache drops 44%
(BF16 1.07GB -> FP8+scale 0.55GB at NB=4096).

Pool layout
- shape: [n_csa, NB, k1_csa, aligned_dim=144] dtypes.fp8 (layer-major
  so pool[pos] is contig per CSA layer)
- per row: [head_dim] FP8 + 4-byte fp32 scale, 16B-aligned
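
The aligned_dim=144 figure follows from the row layout if head_dim=128 (an assumption here, not stated in the commit); the same arithmetic also reproduces the ~44% pool shrink vs BF16:

```python
def aligned_row_bytes(head_dim, scale_bytes=4, align=16):
    """Per-row layout: head_dim FP8 bytes (1 byte/element) followed by
    one fp32 scale, padded up to a 16-byte boundary."""
    raw = head_dim + scale_bytes
    return (raw + align - 1) // align * align

row_fp8 = aligned_row_bytes(128)     # 128 + 4 = 132, padded to 144
row_bf16 = 128 * 2                   # BF16 baseline: 2 bytes/element
savings = 1 - row_fp8 / row_bf16     # 0.4375, i.e. the ~44% drop
```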

Write path (Compressor.forward, idx_slot_mapping is not None)
- Compressor gains optional idx_slot_mapping (int64). When set, the
  fused-compress kernel skips its BF16 scatter and we instead call
  indexer_k_quant_and_cache(out, kv_cache, slot_mapping, head_dim,
  scale_fmt) to FP8-quantize+write each compress row in one shot.
- Slot mapping built host-side in _build_indexer_compress_slot_mapping
  from csa_compress_plan_cpu + block_tables (no extra GPU->CPU copy
  thanks to the new compress_plan_cpu field on CompressPlan).

Read path (Indexer.forward_batched)
- cp_gather_indexer_k_quant_cache(kv_cache, k_fp8, k_scale.view(fp8),
  block_tables, cu_committed_gpu) does paged-gather + split into
  separate (FP8, scale) buffers in one launch -- no per-row index list,
  no online quant.
- Then fp8_mqa_logits over [Q_fp8, K_fp8, kv_scales=k_scale, weights]
  drops the legacy gather_compressed + BF16 einsum path entirely.

Builder side
- _build_v4_indexer_meta gains csa_compress_plan_cpu param; produces
  compress_slot_mapping_gpu (int64, kernel sig is int64_t*) and
  cu_committed_gpu (int32, kernel sig is int32_t*).
- "indexer" gather buffer set removed -- cp_gather_indexer_k_quant_cache
  consumes block_tables + cu_seq_lens directly.
- CompressPlan grows compress_plan_cpu: np.ndarray | None for the same
  reason: builder needs the plan rows host-side to derive slot_mapping
  without an extra D2H sync.

Shape contract gotcha (root cause of a GPU memory-access-fault hunt)
- Indexer.kv_cache binding MUST keep [NB, k1_csa, aligned_dim] (3D,
  block_size dim explicit). Flattening to [NB*k1, 1, aligned_dim] makes
  cp_gather_indexer_k_quant_cache infer block_size=1 from shape[1],
  which then OOB-indexes block_table. Matches V3.2's [num_blocks,
  block_size, head_dim] layout (deepseek_v2.py:1049).
- Write side (indexer_k_quant_and_cache) is shape-agnostic -- uses
  slot_mapping flat index -- so the symmetric 3D binding for the inner
  Compressor is for clarity, not correctness.

Validation
- simple_inference V4-Pro tp=8 fp8 enforce-eager: all 4 prompts produce
  coherent output (1+2+3=**6**, prime list, Chinese long-form).
- GSM8K-100 num_fewshot=3: flexible-extract / strict-match both 0.96
  +/- 0.0197 (baseline 0.97 +/- 0.017, within tolerance).
…deepgemm

Three changes folded into one commit (validated together GSM8K-100=0.97 ±
0.0171, baseline 8ab1367 also 0.97):

1. **preshuffle on indexer write+read** (`indexer_k_quant_and_cache` +
   `cp_gather_indexer_k_quant_cache`): MFMA 16x16 tile-aware FP8 cache
   layout, matches V3.2/PR #658 convention. Required by
   `deepgemm_fp8_paged_mqa_logits` for `KVBlockSize > 1`.

2. **split `Indexer.forward_batched` into prefill/decode helpers**: common
   path (Q proj+RoPE+rotate+FP8 quant, weights computation) stays in
   `forward_batched`; dispatch via `context.is_prefill` to
   `_score_topk_prefill` (cp_gather + fp8_mqa_logits, eager-only —
   variable `total_committed` shape) or `_score_topk_decode` (deepgemm,
   fixed-shape `[bs*next_n, max_model_len_idx]`). Mixed batches go through
   prefill path. `_post_process_topk` shared, branches on `is_decode` to
   skip the seq_base subtraction (decode topk indices are already seq-local;
   prefill indices are global flat positions across cu_committed).

3. **decode helper uses `deepgemm_fp8_paged_mqa_logits`**: reads paged FP8
   cache directly via 4D view `[NB, k1_csa=32, 1, aligned_dim=144]`, writes
   into pre-`-inf`-filled logits buffer (cols beyond per-seq context_lens
   stay -inf so PyTorch topk doesn't pick garbage). `width_mask`
   masked_fill handles per-token k_per_token trimming. CUDAGraph-friendly
   shapes — for Phase B/C buffer pre-allocation + capture path.
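
The need for the `-inf` pre-fill can be seen in a small numpy sketch (argsort stands in for torch.topk): without re-initialization, stale columns from a previous forward outrank real scores.

```python
import numpy as np

def topk_idx(row, k):
    return set(np.argsort(row, kind="stable")[::-1][:k].tolist())

max_len, ctx, k = 8, 4, 3
stale = np.full(max_len, 7.0, dtype=np.float32)  # garbage from a prior fwd

# Reused buffer without re-init: stale columns past ctx win topk.
row = stale.copy()
row[:ctx] = [0.5, 2.0, 1.0, 0.25]
bad = topk_idx(row, k)                           # picks indices >= ctx

# With the -inf pre-fill, only in-context columns can win.
row = np.full(max_len, -np.inf, dtype=np.float32)
row[:ctx] = [0.5, 2.0, 1.0, 0.25]
good = topk_idx(row, k)
```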

Builder: expose `n_committed_per_seq_gpu` (int64, [bs]) in indexer_meta —
no new H2D, just lifts the existing staged tensor into the return dict for
deepgemm context_lens consumption.

Init-time hoist: `Indexer._max_model_len_idx = args.max_seq_len //
compress_ratio` — deepgemm output column count, constant per layer.

Composition validated standalone (test_decode_deepgemm_vs_fp8_mqa.py:
100% top-K overlap with `cp_gather + fp8_mqa_logits` baseline given
`-inf`-init buffer). Numerical round-trip with cache_stride=144 +
preshuffle validated (test_indexer_roundtrip_numerical.py: cos≥0.9995
across all num_tokens / dispatch branches).

Net: +119 / -20 LoC. Phase B/C (decode logits buffer pre-alloc +
build_for_cudagraph_capture) tracked separately.
Replaces the torch.topk + -inf fill path in `Indexer._score_topk_*`
with aiter `top_k_per_row_decode/prefill` (radix kernel, parametric k).
Both paths emit a uniform [total_tokens, index_topk] int32 layout.

  _score_topk_decode (CG-friendly path):
    - Pre-allocated [max_bs, index_topk] int32 indices buffer in builder.
    - Pre-allocated [max_bs, max_model_len_idx] fp32 logits buffer.
    - Drop `fill_(-inf)`: top_k_per_row_decode honors n_committed_per_seq
      per row, so logits cells past valid range are never read.
    - Drop torch.topk + .to(int32) cast.

  _score_topk_prefill (eager-only path):
    - Drop torch.topk + dynamic-`max_k` shape; emit
      [total_tokens, index_topk] via top_k_per_row_prefill(k=index_topk),
      kernel writes -1 sentinels in tail cols.
    - Per-fwd torch.empty for indices (prefill total_tokens dynamic).

  Builder _build_v4_indexer_meta:
    - v4_indexer_n_committed_per_seq buffer i64 -> i32 (kernel arg dtype).
    - Add v4_indexer_decode_logits and v4_indexer_decode_topk_indices
      forward_vars buffers.
    - width_mask collapses to uniform [total_tokens, index_topk] bool.
    - Drop max_k from returned dict; empty-batch guard now keys on
      total_committed == 0.

  Builder _build_v4_pack_meta_for_ratio:
    - compress_topk_src stride is `index_topk` for both paths (was the
      dynamic max_k = max(k_per_seq), which assumed prefill's
      torch.topk(max_k) output shape).

  _post_process_topk:
    - Input contract changes to [total_tokens, index_topk] uniform layout.

Depends on ROCm/aiter#3012 (exposes `k` kwarg on top_k_per_row_decode /
top_k_per_row_prefill); existing aiter without that PR will silently
ignore the kwarg and run with k=2048 (still correct, but allocates an
oversized output buffer).

Validation:
  - aiter kernel parity at v4 shapes (k=1024, varying bs/ctx) - all OK.
  - GSM8K-100 num_fewshot=3 eager: 0.97 / 0.97 (stable vs 0.96 baseline).
Enable CUDAGraph capture for DeepSeek-V4 (Pro / non-Pro) sparse decode.
Final config validated: cudagraph-capture-sizes [1,2,4,8,16,32,64] +
max-num-seqs 64, GSM8K-50 = 0.98.

== Approach ==

Upstream V4 reference materializes "indexer-selected K's" into a
per-fwd dense `kv_flat_sa` tensor whose shape depends on device-side
data — this prevents CUDAGraph capture. ATOM replaces it with a paged
interface (single base pointer + packed-cumsum kv_indptr + kv_indices)
backed by per-layer unified BF16 pool, plus a dedicated triton kernel
that handles V4-specific attn_sink + page_size=1.

== Components ==

1. New triton kernel `sparse_attn_v4_paged_decode`
   (atom/model_ops/v4_kernels/paged_decode.py): page_size=1 sparse
   attention with attn_sink, API aligned with V3.2 mla_decode_fwd
   naming. 3 unit tests bit-exact vs reference.

2. Per-layer `unified_kv` pool (Phase A,
   atom/model_ops/attentions/deepseek_v4_attn.py allocator):
   physically merges SWA ring buffer and compressor paged KV into one
   contiguous BF16 tensor — kernel uses one base pointer, every index
   (SWA / CSA / HCA) is a row offset.

3. Per-fwd paged-decode index construction (Phase B,
   `_attach_v4_paged_decode_meta`): builds 3 kv_indptr cumsums (SWA
   uniform stride, CSA / HCA packed) + scatters SWA window prefix +
   fully populates HCA compress section. All in stable forward_vars
   buffers, no device-data-dependent allocation.

4. CSA per-layer translation kernel `csa_packed_write` (Phase D2,
   atom/model_ops/v4_kernels/csa_packed_write.py): fixed-grid triton
   kernel that translates the indexer's sequence-local topk into
   physical paged offsets and packs them into the per-token CSA
   indices section. Replaces upstream's dynamic-shape scatter.
   Per-seq n_committed + per-token batch_id API (no per-token
   valid_count alias). 7 unit tests pass.

5. Phase-C V4Attention.forward translation (`_fill_csa_paged_compress`,
   atom/models/deepseek_v4.py): per-layer indexer → paged-offset
   translation via fancy indexing on stable per-fwd metadata.

6. Phase-E dispatch (atom/models/deepseek_v4.py:1683): when
   `is_pure_decode` (uniform tokens-per-seq AND no fresh prefill,
   doc §7.4), each V4 attention layer dispatches to
   `sparse_attn_v4_paged_decode`; prefill / mixed batches keep the
   ragged_varlen path.

7. Phase-F MTP-1 capture (`build_for_cudagraph_capture`):
   capture-time uses `max_q_len = 1 + max_spec_steps` and
   physically-replicated packed indptr so MTP-1 fits the same
   captured kernel grid as non-MTP decode.

8. CG-padding sentinel protocol: model_runner pads decode batches up
   to the captured graph_bs; captured kernels still iterate
   `graph_bs * max_q_len` slots at replay. Builder must sentinel-pad
   per-fwd metadata to padded_total_tokens or padded slots read stale
   buffer entries → kv_indices OOB → GPU memory access fault. Two
   sentinel conventions:
     • per-token tensors (batch_id_per_token, swa_write_indices)
       padded with -1; consumer kernels skip on `bid<0` / `src_id<0`
     • indptr cumsums (kv_indptr_{swa,csa,hca}) padded by repeating
       the last cumsum value, making `kv_len = indptr[t+1]-indptr[t] = 0`
       so the kernel inner loop bails without reading kv_indices

   `_fill_csa_paged_compress` clamps batch_id/block_idx before
   PyTorch fancy indexing so padded slots can never trip the GPU
   gather; the bogus paged_compress they produce is dropped by
   csa_packed_write's `bid<0` skip.
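
The indptr-repeat convention in miniature (a sketch; real buffers are pre-allocated forward_vars, not fresh arrays):

```python
import numpy as np

def pad_indptr(kv_lens, padded_bs):
    """Sentinel convention for CG-padded decode batches: extend the
    cumsum by repeating its last value, so every padded slot sees
    kv_len = indptr[t+1] - indptr[t] = 0 and the kernel inner loop
    bails before touching kv_indices."""
    n = len(kv_lens)
    indptr = np.zeros(padded_bs + 1, dtype=np.int32)
    indptr[1:n + 1] = np.cumsum(kv_lens)
    indptr[n + 1:] = indptr[n]          # repeat last cumsum value
    return indptr

indptr = pad_indptr([5, 3], padded_bs=4)   # real bs=2, captured graph_bs=4
kv_lens = indptr[1:] - indptr[:-1]         # padded slots get length 0
```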

9. swa_write rewritten to take stable forward_vars buffers directly
   (full kv + write_indices + positions + batch_id_per_token +
   state_slot_per_seq). The previous `kv[swa_write_indices].contiguous()`
   fancy-index allocated a fresh tensor per call — captured-region
   alloc that we eliminated. Per-token state slot is looked up
   inside the kernel via `state_slot_per_seq[batch_id_per_token[src_id]]`
   — single-per-token-mapping principle (no per-token slot alias).
   5 unit tests pass.

10. Typed AttentionMetaData_DSV4 dataclass replaces ad-hoc setattr on
    the shared base. attn_metadata fields drop the `v4_` prefix
    (subclass provides namespace); buffer keys keep the prefix
    (forward_vars dict is shared across backends). Field-level
    docstrings record shape + dtype.

11. Per-fwd H2D dedup. Builder used to stage four functionally
    redundant tensors per fwd (multiplied by ~62 V4-Pro layers per
    decode step):

      v4_meta_state_slot_i32        ↔ v4_meta_state_slot_groups
      v4_csa_valid_count_per_seq    ↔ v4_indexer_n_committed_per_seq
      v4_indexer_batch_id_per_token ↔ v4_batch_id_per_token

    Consolidated into single-source-of-truth fields on attn_metadata:

      state_slot_mapping        int32 [bs]
      n_committed_csa_per_seq   int32 [bs]   (rename + drop builder
                                              `np.minimum(., index_topk)`
                                              clamp; csa_packed_write
                                              kernel masks `(k <
                                              n_committed) & (k <
                                              index_topk)`)
      batch_id_per_token        int64 [mnbt] (was int32; one buffer
                                              now satisfies both
                                              PyTorch fancy-index and
                                              triton kernels)

    `_attach_v4_per_fwd_meta` produces these unconditionally;
    `_build_v4_indexer_meta` reads them instead of re-staging. Required
    reordering `_attach_v4_per_fwd_meta` BEFORE
    `_attach_sparse_layout_metadata` in all three entry points
    (prepare_decode, prepare_prefill, build_for_cudagraph_capture).

    Net: -4 H2D copies / fwd, -32 KB peak metadata staging memory.

== CG-safety dtype rule (codified) ==

Host-side `.long()` / `.to(int64)` cast on the per-fwd path is
forbidden: it allocates a fresh int64 tensor whose data_ptr varies
across forwards, so a captured graph reads stale memory at replay
(→ GPU memory access fault). Anything a downstream PyTorch fancy-index
needs as int64 must be staged as int64 at the source. That's why
batch_id_per_token graduated to int64 — shared across triton kernels
(don't care about dtype) AND PyTorch fancy-index (requires int64).
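
A host-side illustration of the rule, using numpy buffer addresses as a stand-in for tensor `data_ptr`: a per-fwd dtype cast allocates fresh storage, while staging into a pre-allocated int64 buffer keeps the address stable.

```python
import numpy as np

# Pre-allocated int64 staging buffer: the address a captured graph reads.
buf = np.zeros(8, dtype=np.int64)
addr0 = buf.ctypes.data

src = np.array([3, 1, 2], dtype=np.int32)

# Forbidden pattern: the cast (numpy analogue of tensor.long())
# allocates a fresh buffer, so its pointer differs every forward --
# a captured graph would read stale memory at replay.
cast = src.astype(np.int64)
fresh_addr_differs = cast.ctypes.data != addr0

# Required pattern: stage as int64 at the source, into the stable buffer.
buf[:len(src)] = src
addr_stable = buf.ctypes.data == addr0
```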

== Other refactors landed alongside ==

- module_dispatch_ops: `dual_stream_moe.py` deleted (replaced by
  module-level dispatch helpers used uniformly by V3 / V3.2 / V4)
- deepseek_v2.py: minor adjustments to share helpers with V4 codepath

== Validation ==

  | config                                              | GSM8K          |
  |-----------------------------------------------------|----------------|
  | eager + max_num_seqs=16                             | 0.97 (30)      |
  | CG [1,2,4]   + max_num_seqs=16                      | 1.00 (50)      |
  | CG [1,2,4]   + max_num_seqs=64                      | 0.94 (50)      |
  | CG [1,2,4,8] + max_num_seqs=16                      | 0.93 (30)      |
  | CG [1,2,4,8,16,32,64] + max_num_seqs=64             | 0.98 (50)      |
  | unit tests (paged_decode, csa_packed_write, swa_write_v2) | all pass |
Conflict resolution:
- atom/utils/selector.py: combine both new params — keep `use_v4`
  (V4 backend dispatch from HEAD) and `use_sglang` (sglang GDN routing
  from main). V4 takes precedence over MLA → GDN → default.
- atom/model_engine/llm_engine.py: keep HEAD's frozenset (adds
  `deepseek_v4` to per-req-cache model types).

Verified all 5 dispatch branches (V4, MLA, GDN+sglang, GDN+vllm,
default) resolve to the correct backend class.
valarLip added 2 commits May 5, 2026 03:49
…rnel

Replace `_fill_csa_paged_compress`'s 7-op PyTorch chain (// + % + 2× clamp +
fancy index + arithmetic + where) plus the standalone `csa_packed_write`
triton kernel with one fused `csa_translate_pack` kernel that does indexer
topk → paged-offset translation and bounded packed write entirely in
registers.

Per V4-Pro fwd (31 CSA layers):
  - Eliminates ~155 transient [T, index_topk] tensor allocs (5+/layer)
  - Collapses CG graph nodes from 7-8/layer to 1/layer
  - One triton launch per CSA layer instead of 7 PyTorch ops + 1 launch
  - In-kernel `bid<0` sentinel + `blk_safe = clamp(blk_idx, 0, mnbps-1)`
    handles CG-padded slots without a separate `where(..., -1)` pass

Correctness:
  - paged_decode strict-slices `kv_indices_csa[indptr[t] : indptr[t+1]]`
    whose length is exactly `window_size + n_committed_csa[bid]`. The
    fused kernel writes only `[0, min(valid_k, index_topk))`; the tail is
    never read downstream — the prior `where(..., -1)` debug fill was
    observability-only.
  - Caller passes RAW `n_committed_csa_per_seq` (= ctx_len // 4); kernel
    clamps internally via `(k < valid_k) & (k < index_topk)` mask. Raw
    value is required because the indexer also reads this same buffer.

API:
  - csa_translate_pack(): scalar config args (swa_pages, window_size,
    csa_block_capacity) are keyword-only via `*` — prevents transposed
    positional args silently corrupting offsets.

Validation:
  - 10 unit tests vs PyTorch reference, all bit-exact:
    decode_fast, mtp1, short_valid, zero_valid, cg_padded, mixed,
    n_committed_overflow (valid_k > index_topk → in-kernel clamp),
    empty_T (T=0 early return), all_sentinel (real_T=0 fully padded),
    v4pro_realistic (bs=64, MTP-1, index_topk=1024, mnbps=8192).
  - eager GSM8K-50 = 0.94 (matches V4 baseline).
  - wider CG `[1,2,4,8,16,32,64]` + max-num-seqs=64:
    GSM8K-50 = 0.98, GSM8K-100 = 0.98 (matches pre-fusion cb7f84f).
  - No fault, no hang, no precision regression.

Cleanup (Fix-then-sweep):
  - Delete obsolete csa_packed_write.py + its unit test
  - Update 7 stale `csa_packed_write` references in deepseek_v4_attn.py
    docstrings/comments to point at the fused kernel

Net: 4 files, +32 / −213 production LOC.
Replaces the legacy ragged_varlen prefill path with a CUDAGraph-friendly
dual-source `sparse_attn_v4_paged_prefill` kernel and removes the dead
infrastructure left over from the old packed-input dispatch.

New attention dispatch (V4Attention.forward):
  - is_pure_decode=True  -> swa_write -> sparse_attn_v4_paged_decode
                            (single source = unified_kv ring)
  - is_pure_decode=False -> sparse_attn_v4_paged_prefill -> swa_write
                            (dual source: prefix from unified_kv,
                             extend from per-fwd kv tensor)

  swa_write order differs per branch:
    - decode:  current decode token must read its own K from the ring
               BEFORE attn fires
    - prefill: prior-chunk SWA prefix must read prior data BEFORE the
               current swa_write overwrites ring slots `pos % win`
               (chunked prefill correctness)

New kernel:
  - sparse_attn_v4_paged_prefill triton kernel + PyTorch reference +
    unit test (atom/model_ops/v4_kernels/paged_prefill.py)
  - Indexes (a) `unified_kv` for prefix region (per-ratio: SWA history +
    optional CSA topk / HCA all-committed), (b) per-fwd `kv` for extend
    region (in-chunk SWA tail; layer-invariant, shared across ratios)

Indexer / csa_translate_pack:
  - `_score_topk_{prefill,decode}` return RAW seq-local `topk_in_seq`
    directly (drop `_post_process_topk`); prefill path inlines the
    `topk_global - seq_base` subtraction in the kernel
  - csa_translate_pack signature: scalar `window_size` -> per-token
    `skip_prefix_len_per_token` array; kernel adds `topk >= 0` write
    mask so indexer's -1 sentinels in tail cols are skipped (prefill
    needs this because per-token visibility cap < n_committed_csa[seq])
  - `_fill_csa_paged_compress` takes raw topk arg + dispatches decode
    vs prefill buffer pair internally

Builder:
  - New `_build_paged_prefill_meta` constructs extend + per-ratio prefix
    indices/indptrs vectorised via `_segment_indices`
  - `_attach_v4_per_fwd_meta` builds `window_topk` as a builder-internal
    local (no longer exposed on attn_metadata) and passes it to
    `_attach_v4_paged_decode_meta`
  - `_attach_sparse_layout_metadata` -> `_attach_v4_indexer_meta`
    (renamed so the name matches what the function now does — after the
    sparse_layouts builder body was deleted, it only attaches indexer_meta)

Numerics fix (attention sink finalization):
  Online softmax + sink integration treats sink as a virtual extra K
  with V_sink=0. After the main K loops (m_i, l_i, acc) are in m_i
  frame; sink may shift max to m_final = max(m_i, sink), so BOTH l_i
  (denominator) AND acc (numerator) must be rescaled by
  alpha = exp(m_i - m_final) to switch frame. Previous code only
  rescaled l_i, leaving acc in the m_i frame while dividing by
  l_final (m_final frame) — invisible in V4-Pro production
  (per-token K count ~1000+ -> max(scores) >> sink -> alpha=1) but
  manifests in unit tests with small K counts. Fix applied uniformly
  to all 5 sink finalization sites (max|triton-ref| in sink-dominant
  case dropped from 1.375 to 7.8e-3, BF16 noise floor).
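
The fix can be checked numerically with a minimal numpy model of the sink finalization (`rescale_acc=False` reproduces the pre-fix behavior; a sink-dominant case makes the frame mismatch visible):

```python
import numpy as np

def attn_with_sink(scores, v, sink, rescale_acc):
    """Online-softmax finalization with an attention sink modelled as a
    virtual extra key with V_sink = 0. After the K loop, (m_i, l_i, acc)
    are in the m_i frame; the sink may move the running max, so both the
    denominator l_i AND the numerator acc must be rescaled by
    alpha = exp(m_i - m_final)."""
    m_i = scores.max()
    l_i = np.exp(scores - m_i).sum()
    acc = np.exp(scores - m_i) @ v
    m_final = max(m_i, sink)
    alpha = np.exp(m_i - m_final)
    l_final = l_i * alpha + np.exp(sink - m_final)
    if rescale_acc:
        acc = acc * alpha               # the fix: switch acc to m_final frame
    return acc / l_final

scores = np.array([0.1, -0.2])          # small K count, sink >> scores
v = np.array([[1.0, 0.0], [0.0, 1.0]])
sink = 5.0

# One-pass reference: softmax over [scores, sink] with V_sink = 0.
m = max(scores.max(), sink)
ref = np.exp(scores - m) @ v / (np.exp(scores - m).sum() + np.exp(sink - m))

fixed = attn_with_sink(scores, v, sink, rescale_acc=True)
buggy = attn_with_sink(scores, v, sink, rescale_acc=False)
```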

Dead-code removal:
  sparse_attn_v4.py:
    sparse_attn_ragged_varlen + 3 helpers (_triton_kernel / _triton /
    _torch dispatcher) + _bucket_topk
  deepseek_v4.py:
    _v4_build_sparse_inputs_batched (packed-input materialization)
    _v4_gather_compressed_batched (its only consumer)
    kv_compress_batched local — Compressor return value unused; the
      scatter side-effect into unified_kv is what matters
    Redundant nested `if ratio == 4:` inside `if self.indexer is not None:`
      (Indexer is None iff ratio != 4)
    Dead `indexer_topk_batched = None` init
  deepseek_v4_attn.py:
    _build_v4_pack_meta_for_ratio (only fed _v4_build_sparse_inputs_batched)
    _build_v4_gather_indices (only fed the deleted pack_meta builder)
    AttentionMetaData_DSV4.sparse_layouts field
    AttentionMetaData_DSV4.window_topk_batched field (now local)
    AttentionMetaData_DSV4.cu_seqlens_q_cpu field — set 3 times, never read
    25 CpuGpuBuffer allocations (v4_*_sparse_topk_*, v4_*_pack_*,
      v4_*_dc_gather_*, v4_indexer_k_per_token)
    _build_v4_indexer_meta dead locals: offset_per_seq_np / k_per_seq_cpu
      / k_per_token_np / k_per_token_gpu (computed + staged but never
      surfaced in the indexer_meta dict)
    Indexer dead fields: width_mask_gpu, offset_per_token_gpu,
      is_prefill_per_token_gpu, future_threshold_gpu + their CpuGpuBuffer
      allocations
    `_attach_v4_per_fwd_meta` / `_build_v4_indexer_meta`: drop three
      `bs == 0 / total_tokens == 0 / empty-write` early-None exits.
      warmup + dummy_run + CG capture all guarantee bs >= 1 and
      total_tokens >= 1; document the contract instead.
    Consumer guards: matching `if swa_write_indices is not None`,
      `v4_indexer_meta is not None` ternary, and
      `if indexer_meta is None: return torch.full(-1)` guards
    `if indexer_meta["total_committed"] == 0: return torch.full(-1)`
      host-side branch — would freeze into the CUDAGraph at capture
      time; the hot path already handles n_committed=0 natively
  Indexer:
    Redundant `.float()` in `weights = self.weights_proj(...) * q_scale
    * scale` chain — q_scale is fp32, broadcast multiply auto-promotes
    the chain to fp32. Aligned with v2 deepseek implementation.

Other refactors:
  - V4Attention.forward gates on is_dummy_run (matches attention_mha.py)
    since unified_kv binding happens after warmup
  - Remove @support_torch_compile from DeepseekV4Model. Tripped
    AOTAutograd's PropagateUnbackedSymInts pass on `positions.view` at
    --level 3 PIECEWISE because V4Attention.forward pulls many
    forward_context fields directly into the graph. Re-enabling will
    require wrapping V4Attention.forward as a custom op (V2 sidesteps
    this via aiter.unified_attention_with_output in splitting_ops; V4
    has no such op yet).
  - Move V4 hash MoE input_ids side-channel write from
    ModelRunner.run_model to DeepseekV4ForCausalLM.forward. Compiled
    model is opaque to ModelRunner; setting it at the model entrypoint
    makes any caller (production runner, warmup, benchmarks) get
    correct hash routing without separate setup.
  - DeepseekV4Model.forward: drop `**model_kwargs`, drop
    `if positions is None` arange fallback (positions is now required).
  - MoE: drop `_hash_input_ids` NNModule attribute
    (torch.compile silently dropped its mutation). `_hash_topk` reads
    from `forward_context.context` instead. Remove `input_ids` kwarg
    from MoE.forward / Block.forward / MTPBlock.forward call chain.
  - V4Attention.__init__: plumb rotary_emb into compressor/indexer here
    instead of lazily in forward (NNModule setattr inside compiled
    forward = graph break).
  - `_apply_rotary_emb`: replace `Tensor.unflatten(-1, (-1, 2))` with
    `reshape(*shape[:-1], -1, 2)` (equivalent shape transform, simpler).
  - `_V4RoPE.forward` view variable cleanup.
  - model_runner cudagraph_capture_sizes: replace hard assert with
    warn-and-filter so misconfigured `--cudagraph-capture-sizes
    [N>max_num_seqs]` drops oversized entries instead of crashing.
  - Add `Context.input_ids: Optional[torch.Tensor]` field (read by
    callbacks inside Dynamo-opaque custom ops that can't receive it
    via the FusedMoE fixed-signature custom_routing_function).
  - Rename `seqlen_total` -> `num_tokens` in V4Attention.forward
    (the value is ragged-batch flat token count, not per-seq length).

Validation (DeepSeek-V4-Pro tp=8 fp8 --level 0 + CG bs=[1..32]):
  - csa_translate_pack unit test: 12/12 PASS
  - sparse_attn_v4_paged_prefill unit test: 9/9 PASS (max diff 7.8e-3)
  - ruff check F401/F841: all clean
  - GSM8K-100 nshot=3 (4 runs across the cleanup pass):
      0.93 / 0.96 / 0.98 / 0.95  -> avg 0.955, baseline 0.95 met
@sunway513
Collaborator

Full GSM8K 1319-Problem Accuracy Validation -- PR #650 (cb7f84f)

TL;DR: PR #650 achieves 0.9522 (flexible-extract) / 0.9507 (strict-match) on the full GSM8K test set (1319 problems, 5-shot). This is a +25pp lift over the pre-650 baseline (0.70) and matches SGLang B300 0.96 within 1pp.

Results

| Metric           | Score                           |
|------------------|---------------------------------|
| flexible-extract | 1256/1319 = 0.9522              |
| strict-match     | 1254/1319 = 0.9507              |
| errors           | 0/1319                          |
| eval latency     | 1755.6s (1.33s/req avg, conc=4) |

Cross-Reference

| Source               | GSM8K Score | Notes                        |
|----------------------|-------------|------------------------------|
| This run (full 1319) | 0.9522      | Ground truth                 |
| SGLang B300          | 0.96        | Prior full-set eval          |
| Commit body claim    | 0.98        | limit=50, favorable variance |
| ATOM pre-650         | 0.70        | Issue baseline               |

Validation Config

  • Node: MI355X x8 (mi355-gpu-15)
  • Server: --kv_cache_dtype fp8 -tp 8 --max-num-seqs 64 --max-num-batched-tokens 4096 --max-model-len 4096 --gpu-memory-utilization 0.85 --cudagraph-capture-sizes [1,2,4,8,16,32,64]
  • Env: ATOM_USE_TRITON_MOE=1
  • aiter deps: 3 cherry-picks (a6bb4993 #2879, a0f25393 #2980, fefd5136 #3012)
  • Model: DeepSeek-V4-Pro (HF weights)

Conclusion

PR #650 (DeepSeek V4-Pro model class + CUDAGraph infra + batched compressor/indexer + FP8 CSA cache) is accuracy-validated at production grade. The +25pp lift confirms the architectural rewrite is delivering. Ready for merge review.


Validated by: Peng Sun (sunway513) via automated GSM8K evaluation pipeline, 2026-05-05

1 similar comment

ZhangLirong-amd and others added 3 commits May 6, 2026 13:57
…b + rotate_activation; 2. use topk_softplus fused kernel; 3. use mhc_pre in hc_head; 4. add scale_indexer_weights (#701)
Bundles seven related improvements landed during V4-Pro shakedown:

CI:
- Add DeepSeek-V4-Pro per-PR GSM8K entry (models_accuracy.json) with
  ATOM_USE_TRITON_MOE=1 — required, without it accuracy collapses
  from 0.95+ to 0.6. Threshold 0.92, baseline 0.96 from local
  4-run GSM8K-100 average.
- Add DeepSeek-V4-Pro to nightly benchmark matrix (models.json +
  workflow_dispatch toggle) at 1k/1k and 8k/1k across full
  concurrency sweep, same env_vars.

Engine:
- model_runner: the KV alloc cross-check now includes per_req_cache
  tensor bytes; previously V4/GDN warned spuriously (~14% diff)
  because the expected size counted only the pool. Threshold loosened
  from 1% to 3% to absorb allocator alignment and state-init noise.
- scheduler: warn at submit time when a request can never be
  scheduled (input > max_num_batched_tokens, KV blocks > pool, or
  no per_req_cache slot). Previously such a request sat in the waiting
  queue forever, blocking the head of the line with no log output.
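The two engine checks above can be sketched as plain predicates (hypothetical names and signatures; ATOM's actual model_runner/scheduler code differs):

```python
def check_kv_alloc(actual_bytes: int, pool_bytes: int,
                   per_req_cache_bytes: int, rel_tol: float = 0.03) -> bool:
    """KV allocation cross-check: the expected size now includes the
    per-request state cache, not just the KV pool. Returns True when
    the actual allocation is within the relative tolerance."""
    expected = pool_bytes + per_req_cache_bytes
    return abs(actual_bytes - expected) / max(expected, 1) <= rel_tol


def submit_time_warnings(input_tokens: int, max_num_batched_tokens: int,
                         kv_blocks_needed: int, kv_pool_blocks: int,
                         per_req_cache_slots: int) -> list[str]:
    """Reasons a request can *never* be scheduled, checked once at
    submit time so the scheduler warns instead of stalling silently."""
    reasons = []
    if input_tokens > max_num_batched_tokens:
        reasons.append("input exceeds max_num_batched_tokens")
    if kv_blocks_needed > kv_pool_blocks:
        reasons.append("KV blocks needed exceed pool capacity")
    if per_req_cache_slots == 0:
        reasons.append("no per_req_cache slot")
    return reasons
```

With the old pool-only expectation, a ~14% mismatch trips even the loosened 3% tolerance; counting per_req_cache bytes brings the diff back under it.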

deepseek_v4:
- Drop unused rotate_activation import (Indexer K-side rotate is a
  known TODO; only rope_rotate_activation from aiter is called).
  Unblocks PR 650 black + ruff F401 CI.

simple_inference:
- Add 1k/3k arithmetic stress prompts and bump default max-tokens
  256 → 300 to exercise long-prompt + scheduler-warn paths.
@valarLip valarLip marked this pull request as ready for review May 6, 2026 15:58
Copilot AI review requested due to automatic review settings May 6, 2026 15:58
@valarLip valarLip merged commit cd61e44 into main May 6, 2026
39 of 47 checks passed
@valarLip valarLip deleted the feat/deepseek-v4-pr1-skeleton branch May 6, 2026 16:11